You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

DiscordShardedClient.cs 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. using Discord.API;
  2. using Discord.Rest;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.IO;
  6. using System.Linq;
  7. using System.Threading.Tasks;
  8. using System.Threading;
  namespace Discord.WebSocket
  {
      /// <summary>
      ///     A Discord client that owns several <see cref="DiscordSocketClient"/> shards and exposes
      ///     their combined state (guilds, channels, users) and forwards their events.
      /// </summary>
      public partial class DiscordShardedClient : BaseSocketClient, IDiscordClient
      {
          private readonly DiscordSocketConfig _baseConfig;
          // Handed to every shard's constructor; NOTE(review): presumably serializes shard connections — confirm in DiscordSocketClient.
          private readonly SemaphoreSlim _connectionGroupLock;
          // Maps a shard id to its position in _shardIds / _shards.
          private readonly Dictionary<int, int> _shardIdsToIndex;
          // True when TotalShards was not configured; shards are then created during login.
          private readonly bool _automaticShards;
          private int[] _shardIds;
          private DiscordSocketClient[] _shards;
          private int _totalShards;

          /// <inheritdoc />
          public override int Latency { get => GetLatency(); protected set { } }
          /// <inheritdoc />
          public override UserStatus Status { get => _shards[0].Status; protected set { } }
          /// <inheritdoc />
          public override IActivity Activity { get => _shards[0].Activity; protected set { } }

          internal new DiscordSocketApiClient ApiClient => base.ApiClient as DiscordSocketApiClient;

          /// <inheritdoc />
          public override IReadOnlyCollection<SocketGuild> Guilds => GetGuilds().ToReadOnlyCollection(GetGuildCount);
          /// <inheritdoc />
          public override IReadOnlyCollection<ISocketPrivateChannel> PrivateChannels => GetPrivateChannels().ToReadOnlyCollection(GetPrivateChannelCount);
          /// <summary>
          ///     Gets the shards managed by this client, ordered as their ids were supplied
          ///     (or 0..TotalShards-1 when no ids were given).
          /// </summary>
          public IReadOnlyCollection<DiscordSocketClient> Shards => _shards;
          /// <inheritdoc />
          public override IReadOnlyCollection<RestVoiceRegion> VoiceRegions => _shards[0].VoiceRegions;

          /// <summary> Creates a new REST/WebSocket Discord client. </summary>
          public DiscordShardedClient() : this(null, new DiscordSocketConfig()) { }
          /// <summary> Creates a new REST/WebSocket Discord client. </summary>
          public DiscordShardedClient(DiscordSocketConfig config) : this(null, config, CreateApiClient(config)) { }
          /// <summary> Creates a new REST/WebSocket Discord client. </summary>
          public DiscordShardedClient(int[] ids) : this(ids, new DiscordSocketConfig()) { }
          /// <summary> Creates a new REST/WebSocket Discord client. </summary>
          public DiscordShardedClient(int[] ids, DiscordSocketConfig config) : this(ids, config, CreateApiClient(config)) { }

          /// <summary>
          ///     Validates the configuration and, when <c>TotalShards</c> is known up front, creates the shards immediately.
          ///     Otherwise shard creation is deferred to <see cref="OnLoginAsync"/>.
          /// </summary>
          /// <exception cref="ArgumentException">
          ///     <c>ShardId</c> is set, or custom <paramref name="ids"/> are given without <c>TotalShards</c>.
          /// </exception>
          private DiscordShardedClient(int[] ids, DiscordSocketConfig config, API.DiscordSocketApiClient client)
              : base(config, client)
          {
              if (config.ShardId != null)
                  throw new ArgumentException($"{nameof(config.ShardId)} must not be set.");
              if (ids != null && config.TotalShards == null)
                  throw new ArgumentException($"Custom ids are not supported when {nameof(config.TotalShards)} is not specified.");

              _shardIdsToIndex = new Dictionary<int, int>();
              config.DisplayInitialLog = false;
              _baseConfig = config;
              _connectionGroupLock = new SemaphoreSlim(1, 1);

              if (config.TotalShards == null)
                  _automaticShards = true;
              else
              {
                  int totalShards = config.TotalShards.Value;
                  InitializeShards(ids ?? Enumerable.Range(0, totalShards).ToArray(), totalShards);
              }
          }

          private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
              => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent);

          /// <summary>
          ///     Creates one <see cref="DiscordSocketClient"/> per entry of <paramref name="shardIds"/>, each configured
          ///     with its own <c>ShardId</c> and the shared <paramref name="totalShards"/>, and wires its events to this client.
          ///     Shards after the first are given the first shard as their parent.
          /// </summary>
          /// <exception cref="ArgumentException"><paramref name="shardIds"/> contains a duplicate id.</exception>
          private void InitializeShards(int[] shardIds, int totalShards)
          {
              _shardIds = shardIds;
              _totalShards = totalShards;
              // Drop any mapping left over from a previous shard set so re-initialization cannot collide.
              _shardIdsToIndex.Clear();
              _shards = new DiscordSocketClient[shardIds.Length];
              for (int i = 0; i < shardIds.Length; i++)
              {
                  _shardIdsToIndex.Add(shardIds[i], i);
                  var newConfig = _baseConfig.Clone();
                  newConfig.ShardId = shardIds[i];
                  newConfig.TotalShards = totalShards;
                  _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null);
                  RegisterEvents(_shards[i], i == 0);
              }
          }

          /// <summary>
          ///     In automatic mode, fetches the recommended shard count and builds the shards; then logs every shard in.
          /// </summary>
          internal override async Task OnLoginAsync(TokenType tokenType, string token)
          {
              if (_automaticShards)
              {
                  var shardCount = await GetRecommendedShardCountAsync().ConfigureAwait(false);
                  InitializeShards(Enumerable.Range(0, shardCount).ToArray(), shardCount);
              }
              //Assume thread safe: already in a connection lock
              for (int i = 0; i < _shards.Length; i++)
                  await _shards[i].LoginAsync(tokenType, token).ConfigureAwait(false);
          }

          /// <summary>
          ///     Logs every shard out and clears the current user. In automatic mode the shards are discarded so the
          ///     next login can recreate them from a fresh recommended count.
          /// </summary>
          internal override async Task OnLogoutAsync()
          {
              //Assume thread safe: already in a connection lock
              if (_shards != null)
              {
                  for (int i = 0; i < _shards.Length; i++)
                      await _shards[i].LogoutAsync().ConfigureAwait(false);
              }
              CurrentUser = null;
              if (_automaticShards)
              {
                  _shardIds = new int[0];
                  _shardIdsToIndex.Clear();
                  _totalShards = 0;
                  _shards = null;
              }
          }

          /// <inheritdoc />
          public override async Task StartAsync()
              => await Task.WhenAll(_shards.Select(x => x.StartAsync())).ConfigureAwait(false);
          /// <inheritdoc />
          public override async Task StopAsync()
              => await Task.WhenAll(_shards.Select(x => x.StopAsync())).ConfigureAwait(false);

          /// <summary> Gets the shard with the given shard id. </summary>
          /// <param name="id">The shard id (not the index within <see cref="Shards"/>).</param>
          /// <returns>The shard, or <c>null</c> if this client does not manage a shard with that id.</returns>
          public DiscordSocketClient GetShard(int id)
          {
              if (_shardIdsToIndex.TryGetValue(id, out var index))
                  return _shards[index];
              return null;
          }

          // Shard formula: (guild_id >> 22) % total_shards.
          private int GetShardIdFor(ulong guildId)
              => (int)((guildId >> 22) % (uint)_totalShards);
          /// <summary>
          ///     Computes the id of the shard responsible for <paramref name="guild"/> as
          ///     <c>(guildId &gt;&gt; 22) % totalShards</c>. A <c>null</c> guild is treated as id 0.
          /// </summary>
          public int GetShardIdFor(IGuild guild)
              => GetShardIdFor(guild?.Id ?? 0);
          private DiscordSocketClient GetShardFor(ulong guildId)
              => GetShard(GetShardIdFor(guildId));
          /// <summary> Gets the shard responsible for <paramref name="guild"/>. </summary>
          /// <returns>The shard, or <c>null</c> if that shard is not managed by this client.</returns>
          public DiscordSocketClient GetShardFor(IGuild guild)
              => GetShardFor(guild?.Id ?? 0);

          /// <inheritdoc />
          public override async Task<RestApplication> GetApplicationInfoAsync(RequestOptions options = null)
              => await _shards[0].GetApplicationInfoAsync(options).ConfigureAwait(false);

          /// <inheritdoc />
          /// <remarks> Returns <c>null</c> when the guild's shard is not managed by this client. </remarks>
          public override SocketGuild GetGuild(ulong id)
              => GetShardFor(id)?.GetGuild(id);

          /// <inheritdoc />
          public override SocketChannel GetChannel(ulong id)
          {
              for (int i = 0; i < _shards.Length; i++)
              {
                  var channel = _shards[i].GetChannel(id);
                  if (channel != null)
                      return channel;
              }
              return null;
          }
          private IEnumerable<ISocketPrivateChannel> GetPrivateChannels()
          {
              for (int i = 0; i < _shards.Length; i++)
              {
                  foreach (var channel in _shards[i].PrivateChannels)
                      yield return channel;
              }
          }
          private int GetPrivateChannelCount()
          {
              int result = 0;
              for (int i = 0; i < _shards.Length; i++)
                  result += _shards[i].PrivateChannels.Count;
              return result;
          }
          private IEnumerable<SocketGuild> GetGuilds()
          {
              for (int i = 0; i < _shards.Length; i++)
              {
                  foreach (var guild in _shards[i].Guilds)
                      yield return guild;
              }
          }
          private int GetGuildCount()
          {
              int result = 0;
              for (int i = 0; i < _shards.Length; i++)
                  result += _shards[i].Guilds.Count;
              return result;
          }

          /// <inheritdoc />
          public override SocketUser GetUser(ulong id)
          {
              for (int i = 0; i < _shards.Length; i++)
              {
                  var user = _shards[i].GetUser(id);
                  if (user != null)
                      return user;
              }
              return null;
          }
          /// <inheritdoc />
          public override SocketUser GetUser(string username, string discriminator)
          {
              for (int i = 0; i < _shards.Length; i++)
              {
                  var user = _shards[i].GetUser(username, discriminator);
                  if (user != null)
                      return user;
              }
              return null;
          }

          /// <inheritdoc />
          public override RestVoiceRegion GetVoiceRegion(string id)
              => _shards[0].GetVoiceRegion(id);

          /// <inheritdoc />
          /// <exception cref="ArgumentNullException"><paramref name="guilds"/> is <see langword="null"/></exception>
          public override async Task DownloadUsersAsync(IEnumerable<IGuild> guilds)
          {
              if (guilds == null) throw new ArgumentNullException(nameof(guilds));
              // Bucket guilds by shard id in a single pass: the caller's sequence is enumerated once,
              // and guilds belonging to shards this client does not manage are simply never requested.
              var guildsByShard = guilds.ToLookup(guild => GetShardIdFor(guild));
              for (int i = 0; i < _shards.Length; i++)
              {
                  var arr = guildsByShard[_shardIds[i]].ToArray();
                  if (arr.Length > 0)
                      await _shards[i].DownloadUsersAsync(arr).ConfigureAwait(false);
              }
          }

          /// <summary> Average latency across all shards, rounded; 0 when no shards exist yet. </summary>
          private int GetLatency()
          {
              // In automatic mode _shards is null until login; averaging over zero shards would also yield NaN.
              if (_shards == null || _shards.Length == 0)
                  return 0;
              int total = 0;
              for (int i = 0; i < _shards.Length; i++)
                  total += _shards[i].Latency;
              return (int)Math.Round(total / (double)_shards.Length);
          }

          /// <inheritdoc />
          public override async Task SetStatusAsync(UserStatus status)
          {
              for (int i = 0; i < _shards.Length; i++)
                  await _shards[i].SetStatusAsync(status).ConfigureAwait(false);
          }
          /// <inheritdoc />
          public override async Task SetGameAsync(string name, string streamUrl = null, ActivityType type = ActivityType.Playing)
          {
              IActivity activity = null;
              if (!string.IsNullOrEmpty(streamUrl))
                  activity = new StreamingGame(name, streamUrl);
              else if (!string.IsNullOrEmpty(name))
                  activity = new Game(name, type);
              await SetActivityAsync(activity).ConfigureAwait(false);
          }
          /// <inheritdoc />
          public override async Task SetActivityAsync(IActivity activity)
          {
              for (int i = 0; i < _shards.Length; i++)
                  await _shards[i].SetActivityAsync(activity).ConfigureAwait(false);
          }

          /// <summary>
          ///     Forwards a shard's events to this client's events. The primary shard also supplies
          ///     <see cref="CurrentUser"/> when it becomes ready.
          /// </summary>
          private void RegisterEvents(DiscordSocketClient client, bool isPrimary)
          {
              client.Log += (msg) => _logEvent.InvokeAsync(msg);
              client.LoggedOut += () =>
              {
                  var state = LoginState;
                  if (state == LoginState.LoggedIn || state == LoginState.LoggingIn)
                  {
                      //Should only happen if token is changed
                      var _ = LogoutAsync(); //Signal the logout, fire and forget
                  }
                  return Task.Delay(0);
              };
              if (isPrimary)
              {
                  client.Ready += () =>
                  {
                      CurrentUser = client.CurrentUser;
                      return Task.Delay(0);
                  };
              }

              client.Connected += () => _shardConnectedEvent.InvokeAsync(client);
              client.Disconnected += (exception) => _shardDisconnectedEvent.InvokeAsync(exception, client);
              client.Ready += () => _shardReadyEvent.InvokeAsync(client);
              client.LatencyUpdated += (oldLatency, newLatency) => _shardLatencyUpdatedEvent.InvokeAsync(oldLatency, newLatency, client);

              client.ChannelCreated += (channel) => _channelCreatedEvent.InvokeAsync(channel);
              client.ChannelDestroyed += (channel) => _channelDestroyedEvent.InvokeAsync(channel);
              client.ChannelUpdated += (oldChannel, newChannel) => _channelUpdatedEvent.InvokeAsync(oldChannel, newChannel);

              client.MessageReceived += (msg) => _messageReceivedEvent.InvokeAsync(msg);
              client.MessageDeleted += (cache, channel) => _messageDeletedEvent.InvokeAsync(cache, channel);
              client.MessageUpdated += (oldMsg, newMsg, channel) => _messageUpdatedEvent.InvokeAsync(oldMsg, newMsg, channel);
              client.ReactionAdded += (cache, channel, reaction) => _reactionAddedEvent.InvokeAsync(cache, channel, reaction);
              client.ReactionRemoved += (cache, channel, reaction) => _reactionRemovedEvent.InvokeAsync(cache, channel, reaction);
              client.ReactionsCleared += (cache, channel) => _reactionsClearedEvent.InvokeAsync(cache, channel);

              client.RoleCreated += (role) => _roleCreatedEvent.InvokeAsync(role);
              client.RoleDeleted += (role) => _roleDeletedEvent.InvokeAsync(role);
              client.RoleUpdated += (oldRole, newRole) => _roleUpdatedEvent.InvokeAsync(oldRole, newRole);

              client.JoinedGuild += (guild) => _joinedGuildEvent.InvokeAsync(guild);
              client.LeftGuild += (guild) => _leftGuildEvent.InvokeAsync(guild);
              client.GuildAvailable += (guild) => _guildAvailableEvent.InvokeAsync(guild);
              client.GuildUnavailable += (guild) => _guildUnavailableEvent.InvokeAsync(guild);
              client.GuildMembersDownloaded += (guild) => _guildMembersDownloadedEvent.InvokeAsync(guild);
              client.GuildUpdated += (oldGuild, newGuild) => _guildUpdatedEvent.InvokeAsync(oldGuild, newGuild);

              client.UserJoined += (user) => _userJoinedEvent.InvokeAsync(user);
              client.UserLeft += (user) => _userLeftEvent.InvokeAsync(user);
              client.UserBanned += (user, guild) => _userBannedEvent.InvokeAsync(user, guild);
              client.UserUnbanned += (user, guild) => _userUnbannedEvent.InvokeAsync(user, guild);
              client.UserUpdated += (oldUser, newUser) => _userUpdatedEvent.InvokeAsync(oldUser, newUser);
              client.GuildMemberUpdated += (oldUser, newUser) => _guildMemberUpdatedEvent.InvokeAsync(oldUser, newUser);
              client.UserVoiceStateUpdated += (user, oldVoiceState, newVoiceState) => _userVoiceStateUpdatedEvent.InvokeAsync(user, oldVoiceState, newVoiceState);
              client.VoiceServerUpdated += (server) => _voiceServerUpdatedEvent.InvokeAsync(server);
              client.CurrentUserUpdated += (oldUser, newUser) => _selfUpdatedEvent.InvokeAsync(oldUser, newUser);
              client.UserIsTyping += (user, channel) => _userIsTypingEvent.InvokeAsync(user, channel);
              client.RecipientAdded += (user) => _recipientAddedEvent.InvokeAsync(user);
              client.RecipientRemoved += (user) => _recipientRemovedEvent.InvokeAsync(user);
          }

          //IDiscordClient
          /// <inheritdoc />
          async Task<IApplication> IDiscordClient.GetApplicationInfoAsync(RequestOptions options)
              => await GetApplicationInfoAsync(options).ConfigureAwait(false);

          /// <inheritdoc />
          Task<IChannel> IDiscordClient.GetChannelAsync(ulong id, CacheMode mode, RequestOptions options)
              => Task.FromResult<IChannel>(GetChannel(id));
          /// <inheritdoc />
          Task<IReadOnlyCollection<IPrivateChannel>> IDiscordClient.GetPrivateChannelsAsync(CacheMode mode, RequestOptions options)
              => Task.FromResult<IReadOnlyCollection<IPrivateChannel>>(PrivateChannels);

          /// <inheritdoc />
          async Task<IReadOnlyCollection<IConnection>> IDiscordClient.GetConnectionsAsync(RequestOptions options)
              => await GetConnectionsAsync().ConfigureAwait(false);

          /// <inheritdoc />
          async Task<IInvite> IDiscordClient.GetInviteAsync(string inviteId, RequestOptions options)
              => await GetInviteAsync(inviteId, options).ConfigureAwait(false);

          /// <inheritdoc />
          Task<IGuild> IDiscordClient.GetGuildAsync(ulong id, CacheMode mode, RequestOptions options)
              => Task.FromResult<IGuild>(GetGuild(id));
          /// <inheritdoc />
          Task<IReadOnlyCollection<IGuild>> IDiscordClient.GetGuildsAsync(CacheMode mode, RequestOptions options)
              => Task.FromResult<IReadOnlyCollection<IGuild>>(Guilds);
          /// <inheritdoc />
          async Task<IGuild> IDiscordClient.CreateGuildAsync(string name, IVoiceRegion region, Stream jpegIcon, RequestOptions options)
              => await CreateGuildAsync(name, region, jpegIcon).ConfigureAwait(false);

          /// <inheritdoc />
          Task<IUser> IDiscordClient.GetUserAsync(ulong id, CacheMode mode, RequestOptions options)
              => Task.FromResult<IUser>(GetUser(id));
          /// <inheritdoc />
          Task<IUser> IDiscordClient.GetUserAsync(string username, string discriminator, RequestOptions options)
              => Task.FromResult<IUser>(GetUser(username, discriminator));

          /// <inheritdoc />
          Task<IReadOnlyCollection<IVoiceRegion>> IDiscordClient.GetVoiceRegionsAsync(RequestOptions options)
              => Task.FromResult<IReadOnlyCollection<IVoiceRegion>>(VoiceRegions);
          /// <inheritdoc />
          Task<IVoiceRegion> IDiscordClient.GetVoiceRegionAsync(string id, RequestOptions options)
              => Task.FromResult<IVoiceRegion>(GetVoiceRegion(id));
      }
  }