You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DiscordShardedClient.cs 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. using Discord.API;
  2. using Discord.Rest;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.IO;
  6. using System.Linq;
  7. using System.Threading.Tasks;
  8. using System.Threading;
  9. namespace Discord.WebSocket
  10. {
  11. public partial class DiscordShardedClient : BaseSocketClient, IDiscordClient
  12. {
  13. private readonly DiscordSocketConfig _baseConfig;
  14. private readonly SemaphoreSlim _connectionGroupLock;
  15. private int[] _shardIds;
  16. private Dictionary<int, int> _shardIdsToIndex;
  17. private DiscordSocketClient[] _shards;
  18. private int _totalShards;
  19. private bool _automaticShards;
  20. /// <summary> Gets the estimated round-trip latency, in milliseconds, to the gateway server. </summary>
  21. public override int Latency { get => GetLatency(); protected set { } }
  22. public override UserStatus Status { get => _shards[0].Status; protected set { } }
  23. public override Game? Game { get => _shards[0].Game; protected set { } }
  24. internal new DiscordSocketApiClient ApiClient => base.ApiClient as DiscordSocketApiClient;
  25. public override IReadOnlyCollection<SocketGuild> Guilds => GetGuilds().ToReadOnlyCollection(() => GetGuildCount());
  26. public override IReadOnlyCollection<ISocketPrivateChannel> PrivateChannels => GetPrivateChannels().ToReadOnlyCollection(() => GetPrivateChannelCount());
  27. public IReadOnlyCollection<DiscordSocketClient> Shards => _shards;
  28. public override IReadOnlyCollection<RestVoiceRegion> VoiceRegions => _shards[0].VoiceRegions;
  29. /// <summary> Creates a new REST/WebSocket discord client. </summary>
  30. public DiscordShardedClient() : this(null, new DiscordSocketConfig()) { }
  31. /// <summary> Creates a new REST/WebSocket discord client. </summary>
  32. public DiscordShardedClient(DiscordSocketConfig config) : this(null, config, CreateApiClient(config)) { }
  33. /// <summary> Creates a new REST/WebSocket discord client. </summary>
  34. public DiscordShardedClient(int[] ids) : this(ids, new DiscordSocketConfig()) { }
  35. /// <summary> Creates a new REST/WebSocket discord client. </summary>
  36. public DiscordShardedClient(int[] ids, DiscordSocketConfig config) : this(ids, config, CreateApiClient(config)) { }
  37. private DiscordShardedClient(int[] ids, DiscordSocketConfig config, API.DiscordSocketApiClient client)
  38. : base(config, client)
  39. {
  40. if (config.ShardId != null)
  41. throw new ArgumentException($"{nameof(config.ShardId)} must not be set.");
  42. if (ids != null && config.TotalShards == null)
  43. throw new ArgumentException($"Custom ids are not supported when {nameof(config.TotalShards)} is not specified.");
  44. _shardIdsToIndex = new Dictionary<int, int>();
  45. config.DisplayInitialLog = false;
  46. _baseConfig = config;
  47. _connectionGroupLock = new SemaphoreSlim(1, 1);
  48. if (config.TotalShards == null)
  49. _automaticShards = true;
  50. else
  51. {
  52. _totalShards = config.TotalShards.Value;
  53. _shardIds = ids ?? Enumerable.Range(0, _totalShards).ToArray();
  54. _shards = new DiscordSocketClient[_shardIds.Length];
  55. for (int i = 0; i < _shardIds.Length; i++)
  56. {
  57. _shardIdsToIndex.Add(_shardIds[i], i);
  58. var newConfig = config.Clone();
  59. newConfig.ShardId = _shardIds[i];
  60. _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null);
  61. RegisterEvents(_shards[i], i == 0);
  62. }
  63. }
  64. }
  65. private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
  66. => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent);
  67. internal override async Task OnLoginAsync(TokenType tokenType, string token)
  68. {
  69. if (_automaticShards)
  70. {
  71. var response = await ApiClient.GetBotGatewayAsync().ConfigureAwait(false);
  72. _shardIds = Enumerable.Range(0, response.Shards).ToArray();
  73. _totalShards = _shardIds.Length;
  74. _shards = new DiscordSocketClient[_shardIds.Length];
  75. for (int i = 0; i < _shardIds.Length; i++)
  76. {
  77. _shardIdsToIndex.Add(_shardIds[i], i);
  78. var newConfig = _baseConfig.Clone();
  79. newConfig.ShardId = _shardIds[i];
  80. newConfig.TotalShards = _totalShards;
  81. _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null);
  82. RegisterEvents(_shards[i], i == 0);
  83. }
  84. }
  85. //Assume threadsafe: already in a connection lock
  86. for (int i = 0; i < _shards.Length; i++)
  87. await _shards[i].LoginAsync(tokenType, token, false);
  88. }
  89. internal override async Task OnLogoutAsync()
  90. {
  91. //Assume threadsafe: already in a connection lock
  92. if (_shards != null)
  93. {
  94. for (int i = 0; i < _shards.Length; i++)
  95. await _shards[i].LogoutAsync();
  96. }
  97. CurrentUser = null;
  98. if (_automaticShards)
  99. {
  100. _shardIds = new int[0];
  101. _shardIdsToIndex.Clear();
  102. _totalShards = 0;
  103. _shards = null;
  104. }
  105. }
  106. /// <inheritdoc />
  107. public override async Task StartAsync()
  108. => await Task.WhenAll(_shards.Select(x => x.StartAsync())).ConfigureAwait(false);
  109. /// <inheritdoc />
  110. public override async Task StopAsync()
  111. => await Task.WhenAll(_shards.Select(x => x.StopAsync())).ConfigureAwait(false);
  112. public DiscordSocketClient GetShard(int id)
  113. {
  114. if (_shardIdsToIndex.TryGetValue(id, out id))
  115. return _shards[id];
  116. return null;
  117. }
  118. private int GetShardIdFor(ulong guildId)
  119. => (int)((guildId >> 22) % (uint)_totalShards);
  120. public int GetShardIdFor(IGuild guild)
  121. => GetShardIdFor(guild.Id);
  122. private DiscordSocketClient GetShardFor(ulong guildId)
  123. => GetShard(GetShardIdFor(guildId));
  124. public DiscordSocketClient GetShardFor(IGuild guild)
  125. => GetShardFor(guild.Id);
  126. /// <inheritdoc />
  127. public override async Task<RestApplication> GetApplicationInfoAsync(RequestOptions options = null)
  128. => await _shards[0].GetApplicationInfoAsync(options).ConfigureAwait(false);
  129. /// <inheritdoc />
  130. public override SocketGuild GetGuild(ulong id)
  131. => GetShardFor(id).GetGuild(id);
  132. /// <inheritdoc />
  133. public override SocketChannel GetChannel(ulong id)
  134. {
  135. for (int i = 0; i < _shards.Length; i++)
  136. {
  137. var channel = _shards[i].GetChannel(id);
  138. if (channel != null)
  139. return channel;
  140. }
  141. return null;
  142. }
  143. private IEnumerable<ISocketPrivateChannel> GetPrivateChannels()
  144. {
  145. for (int i = 0; i < _shards.Length; i++)
  146. {
  147. foreach (var channel in _shards[i].PrivateChannels)
  148. yield return channel;
  149. }
  150. }
  151. private int GetPrivateChannelCount()
  152. {
  153. int result = 0;
  154. for (int i = 0; i < _shards.Length; i++)
  155. result += _shards[i].PrivateChannels.Count;
  156. return result;
  157. }
  158. private IEnumerable<SocketGuild> GetGuilds()
  159. {
  160. for (int i = 0; i < _shards.Length; i++)
  161. {
  162. foreach (var guild in _shards[i].Guilds)
  163. yield return guild;
  164. }
  165. }
  166. private int GetGuildCount()
  167. {
  168. int result = 0;
  169. for (int i = 0; i < _shards.Length; i++)
  170. result += _shards[i].Guilds.Count;
  171. return result;
  172. }
  173. /// <inheritdoc />
  174. public override SocketUser GetUser(ulong id)
  175. {
  176. for (int i = 0; i < _shards.Length; i++)
  177. {
  178. var user = _shards[i].GetUser(id);
  179. if (user != null)
  180. return user;
  181. }
  182. return null;
  183. }
  184. /// <inheritdoc />
  185. public override SocketUser GetUser(string username, string discriminator)
  186. {
  187. for (int i = 0; i < _shards.Length; i++)
  188. {
  189. var user = _shards[i].GetUser(username, discriminator);
  190. if (user != null)
  191. return user;
  192. }
  193. return null;
  194. }
  195. /// <inheritdoc />
  196. public override RestVoiceRegion GetVoiceRegion(string id)
  197. => _shards[0].GetVoiceRegion(id);
  198. /// <summary> Downloads the users list for the provided guilds, if they don't have a complete list. </summary>
  199. public override async Task DownloadUsersAsync(IEnumerable<IGuild> guilds)
  200. {
  201. for (int i = 0; i < _shards.Length; i++)
  202. {
  203. int id = _shardIds[i];
  204. var arr = guilds.Where(x => GetShardIdFor(x) == id).ToArray();
  205. if (arr.Length > 0)
  206. await _shards[i].DownloadUsersAsync(arr).ConfigureAwait(false);
  207. }
  208. }
  209. private int GetLatency()
  210. {
  211. int total = 0;
  212. for (int i = 0; i < _shards.Length; i++)
  213. total += _shards[i].Latency;
  214. return (int)Math.Round(total / (double)_shards.Length);
  215. }
  216. public override async Task SetStatusAsync(UserStatus status)
  217. {
  218. for (int i = 0; i < _shards.Length; i++)
  219. await _shards[i].SetStatusAsync(status).ConfigureAwait(false);
  220. }
  221. public override async Task SetGameAsync(string name, string streamUrl = null, StreamType streamType = StreamType.NotStreaming)
  222. {
  223. for (int i = 0; i < _shards.Length; i++)
  224. await _shards[i].SetGameAsync(name, streamUrl, streamType).ConfigureAwait(false);
  225. }
  226. private void RegisterEvents(DiscordSocketClient client, bool isPrimary)
  227. {
  228. client.Log += (msg) => _logEvent.InvokeAsync(msg);
  229. client.LoggedOut += () =>
  230. {
  231. var state = LoginState;
  232. if (state == LoginState.LoggedIn || state == LoginState.LoggingIn)
  233. {
  234. //Should only happen if token is changed
  235. var _ = LogoutAsync(); //Signal the logout, fire and forget
  236. }
  237. return Task.Delay(0);
  238. };
  239. if (isPrimary)
  240. {
  241. client.Ready += () =>
  242. {
  243. CurrentUser = client.CurrentUser;
  244. return Task.Delay(0);
  245. };
  246. }
  247. client.Connected += () => _shardConnectedEvent.InvokeAsync(client);
  248. client.Disconnected += (exception) => _shardDisconnectedEvent.InvokeAsync(client, exception);
  249. client.Ready += () => _shardReadyEvent.InvokeAsync(client);
  250. client.ChannelCreated += (channel) => _channelCreatedEvent.InvokeAsync(channel);
  251. client.ChannelDestroyed += (channel) => _channelDestroyedEvent.InvokeAsync(channel);
  252. client.ChannelUpdated += (oldChannel, newChannel) => _channelUpdatedEvent.InvokeAsync(oldChannel, newChannel);
  253. client.MessageReceived += (msg) => _messageReceivedEvent.InvokeAsync(msg);
  254. client.MessageDeleted += (cache, channel) => _messageDeletedEvent.InvokeAsync(cache, channel);
  255. client.MessageUpdated += (oldMsg, newMsg, channel) => _messageUpdatedEvent.InvokeAsync(oldMsg, newMsg, channel);
  256. client.ReactionAdded += (cache, channel, reaction) => _reactionAddedEvent.InvokeAsync(cache, channel, reaction);
  257. client.ReactionRemoved += (cache, channel, reaction) => _reactionRemovedEvent.InvokeAsync(cache, channel, reaction);
  258. client.ReactionsCleared += (cache, channel) => _reactionsClearedEvent.InvokeAsync(cache, channel);
  259. client.RoleCreated += (role) => _roleCreatedEvent.InvokeAsync(role);
  260. client.RoleDeleted += (role) => _roleDeletedEvent.InvokeAsync(role);
  261. client.RoleUpdated += (oldRole, newRole) => _roleUpdatedEvent.InvokeAsync(oldRole, newRole);
  262. client.JoinedGuild += (guild) => _joinedGuildEvent.InvokeAsync(guild);
  263. client.LeftGuild += (guild) => _leftGuildEvent.InvokeAsync(guild);
  264. client.GuildAvailable += (guild) => _guildAvailableEvent.InvokeAsync(guild);
  265. client.GuildUnavailable += (guild) => _guildUnavailableEvent.InvokeAsync(guild);
  266. client.GuildMembersDownloaded += (guild) => _guildMembersDownloadedEvent.InvokeAsync(guild);
  267. client.GuildUpdated += (oldGuild, newGuild) => _guildUpdatedEvent.InvokeAsync(oldGuild, newGuild);
  268. client.UserJoined += (user) => _userJoinedEvent.InvokeAsync(user);
  269. client.UserLeft += (user) => _userLeftEvent.InvokeAsync(user);
  270. client.UserBanned += (user, guild) => _userBannedEvent.InvokeAsync(user, guild);
  271. client.UserUnbanned += (user, guild) => _userUnbannedEvent.InvokeAsync(user, guild);
  272. client.UserUpdated += (oldUser, newUser) => _userUpdatedEvent.InvokeAsync(oldUser, newUser);
  273. client.GuildMemberUpdated += (oldUser, newUser) => _guildMemberUpdatedEvent.InvokeAsync(oldUser, newUser);
  274. client.UserVoiceStateUpdated += (user, oldVoiceState, newVoiceState) => _userVoiceStateUpdatedEvent.InvokeAsync(user, oldVoiceState, newVoiceState);
  275. client.CurrentUserUpdated += (oldUser, newUser) => _selfUpdatedEvent.InvokeAsync(oldUser, newUser);
  276. client.UserIsTyping += (oldUser, newUser) => _userIsTypingEvent.InvokeAsync(oldUser, newUser);
  277. client.RecipientAdded += (user) => _recipientAddedEvent.InvokeAsync(user);
  278. client.RecipientRemoved += (user) => _recipientRemovedEvent.InvokeAsync(user);
  279. }
  280. //IDiscordClient
  281. async Task<IApplication> IDiscordClient.GetApplicationInfoAsync(RequestOptions options)
  282. => await GetApplicationInfoAsync().ConfigureAwait(false);
  283. Task<IChannel> IDiscordClient.GetChannelAsync(ulong id, CacheMode mode, RequestOptions options)
  284. => Task.FromResult<IChannel>(GetChannel(id));
  285. Task<IReadOnlyCollection<IPrivateChannel>> IDiscordClient.GetPrivateChannelsAsync(CacheMode mode, RequestOptions options)
  286. => Task.FromResult<IReadOnlyCollection<IPrivateChannel>>(PrivateChannels);
  287. async Task<IReadOnlyCollection<IConnection>> IDiscordClient.GetConnectionsAsync(RequestOptions options)
  288. => await GetConnectionsAsync().ConfigureAwait(false);
  289. async Task<IInvite> IDiscordClient.GetInviteAsync(string inviteId, RequestOptions options)
  290. => await GetInviteAsync(inviteId).ConfigureAwait(false);
  291. Task<IGuild> IDiscordClient.GetGuildAsync(ulong id, CacheMode mode, RequestOptions options)
  292. => Task.FromResult<IGuild>(GetGuild(id));
  293. Task<IReadOnlyCollection<IGuild>> IDiscordClient.GetGuildsAsync(CacheMode mode, RequestOptions options)
  294. => Task.FromResult<IReadOnlyCollection<IGuild>>(Guilds);
  295. async Task<IGuild> IDiscordClient.CreateGuildAsync(string name, IVoiceRegion region, Stream jpegIcon, RequestOptions options)
  296. => await CreateGuildAsync(name, region, jpegIcon).ConfigureAwait(false);
  297. Task<IUser> IDiscordClient.GetUserAsync(ulong id, CacheMode mode, RequestOptions options)
  298. => Task.FromResult<IUser>(GetUser(id));
  299. Task<IUser> IDiscordClient.GetUserAsync(string username, string discriminator, RequestOptions options)
  300. => Task.FromResult<IUser>(GetUser(username, discriminator));
  301. Task<IReadOnlyCollection<IVoiceRegion>> IDiscordClient.GetVoiceRegionsAsync(RequestOptions options)
  302. => Task.FromResult<IReadOnlyCollection<IVoiceRegion>>(VoiceRegions);
  303. Task<IVoiceRegion> IDiscordClient.GetVoiceRegionAsync(string id, RequestOptions options)
  304. => Task.FromResult<IVoiceRegion>(GetVoiceRegion(id));
  305. }
  306. }