You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-'), and can be up to 35 characters long.

DiscordShardedClient.cs 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. using Discord.API;
  2. using Discord.Rest;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.IO;
  6. using System.Linq;
  7. using System.Threading.Tasks;
  8. using System.Threading;
  namespace Discord.WebSocket
  {
      /// <summary>
      /// A Discord client that spreads its gateway connection across several
      /// <see cref="DiscordSocketClient"/> shards and presents them as a single client.
      /// </summary>
      /// <remarks>
      /// If <see cref="DiscordSocketConfig.TotalShards"/> is set, the shards are created in the
      /// constructor (optionally only the subset given by custom shard ids). Otherwise the shard
      /// count is requested from the gateway during login, the shards are created at that point,
      /// and they are discarded again on logout.
      /// </remarks>
      public partial class DiscordShardedClient : BaseDiscordClient, IDiscordClient
      {
          private readonly DiscordSocketConfig _baseConfig;
          // Single-slot semaphore handed to every shard's constructor; presumably used by the
          // shards to serialize their connection attempts — TODO confirm in DiscordSocketClient.
          private readonly SemaphoreSlim _connectionGroupLock;
          // Shard ids hosted by this client; _shardIds[i] is the id of _shards[i].
          private int[] _shardIds;
          // Maps a shard id to its index in _shards / _shardIds.
          private Dictionary<int, int> _shardIdsToIndex;
          private DiscordSocketClient[] _shards;
          // Total number of shards the bot is split into. May exceed _shards.Length when
          // custom shard ids are supplied. Used by GetShardIdFor.
          private int _totalShards;
          // True when TotalShards was not configured and the shard count comes from the gateway.
          private bool _automaticShards;

          /// <summary>
          /// Gets the estimated round-trip latency, in milliseconds, to the gateway server,
          /// averaged over all shards. Returns 0 when no shards have been created yet.
          /// </summary>
          public int Latency => GetLatency();
          /// <summary> Gets the user status as reported by the first shard. </summary>
          public UserStatus Status => _shards[0].Status;
          /// <summary> Gets the game as reported by the first shard. </summary>
          public Game? Game => _shards[0].Game;

          internal new DiscordSocketApiClient ApiClient => base.ApiClient as DiscordSocketApiClient;
          /// <summary> Gets the logged-in user; set when the first shard raises <c>Ready</c>. </summary>
          public new SocketSelfUser CurrentUser { get { return base.CurrentUser as SocketSelfUser; } private set { base.CurrentUser = value; } }
          /// <summary> Gets the guilds cached by all shards hosted by this client. </summary>
          public IReadOnlyCollection<SocketGuild> Guilds => GetGuilds().ToReadOnlyCollection(() => GetGuildCount());
          /// <summary> Gets the private channels cached by all shards hosted by this client. </summary>
          public IReadOnlyCollection<ISocketPrivateChannel> PrivateChannels => GetPrivateChannels().ToReadOnlyCollection(() => GetPrivateChannelCount());
          /// <summary> Gets the shard clients hosted by this client. </summary>
          public IReadOnlyCollection<DiscordSocketClient> Shards => _shards;
          /// <summary> Gets the voice regions as reported by the first shard. </summary>
          public IReadOnlyCollection<RestVoiceRegion> VoiceRegions => _shards[0].VoiceRegions;

          /// <summary> Creates a new REST/WebSocket discord client. </summary>
          public DiscordShardedClient() : this(null, new DiscordSocketConfig()) { }
          /// <summary> Creates a new REST/WebSocket discord client. </summary>
          public DiscordShardedClient(DiscordSocketConfig config) : this(null, config, CreateApiClient(config)) { }
          /// <summary> Creates a new REST/WebSocket discord client. </summary>
          public DiscordShardedClient(int[] ids) : this(ids, new DiscordSocketConfig()) { }
          /// <summary> Creates a new REST/WebSocket discord client. </summary>
          public DiscordShardedClient(int[] ids, DiscordSocketConfig config) : this(ids, config, CreateApiClient(config)) { }

          /// <summary>
          /// Shared constructor. Validates the configuration and, when
          /// <see cref="DiscordSocketConfig.TotalShards"/> is set, creates the shard clients immediately.
          /// </summary>
          /// <param name="ids">Shard ids to host, or null to host all of them. Requires <c>TotalShards</c>.</param>
          /// <param name="config">Base configuration; cloned for every shard.</param>
          /// <param name="client">The API client passed to the base class.</param>
          /// <exception cref="ArgumentException">
          /// <c>ShardId</c> is set, or <paramref name="ids"/> is given without <c>TotalShards</c>.
          /// </exception>
          private DiscordShardedClient(int[] ids, DiscordSocketConfig config, API.DiscordSocketApiClient client)
              : base(config, client)
          {
              if (config.ShardId != null)
                  throw new ArgumentException($"{nameof(config.ShardId)} must not be set.");
              if (ids != null && config.TotalShards == null)
                  throw new ArgumentException($"Custom ids are not supported when {nameof(config.TotalShards)} is not specified.");

              _shardIdsToIndex = new Dictionary<int, int>();
              // Note: this mutates the caller-supplied config object.
              config.DisplayInitialLog = false;
              _baseConfig = config;
              _connectionGroupLock = new SemaphoreSlim(1, 1);

              if (config.TotalShards == null)
                  _automaticShards = true;
              else
              {
                  _totalShards = config.TotalShards.Value;
                  _shardIds = ids ?? Enumerable.Range(0, _totalShards).ToArray();
                  CreateShards();
              }
          }

          private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
              => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent);

          /// <summary>
          /// Creates one <see cref="DiscordSocketClient"/> per entry of <see cref="_shardIds"/>,
          /// rebuilding the id-to-index map and wiring each shard's events to this client.
          /// </summary>
          private void CreateShards()
          {
              // Clear first so a repeated automatic login cannot hit a duplicate-key exception.
              _shardIdsToIndex.Clear();
              _shards = new DiscordSocketClient[_shardIds.Length];
              for (int i = 0; i < _shardIds.Length; i++)
              {
                  _shardIdsToIndex.Add(_shardIds[i], i);
                  var newConfig = _baseConfig.Clone();
                  newConfig.ShardId = _shardIds[i];
                  newConfig.TotalShards = _totalShards;
                  // Every shard after the first receives the first shard as an extra constructor argument.
                  _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null);
                  RegisterEvents(_shards[i], i == 0);
              }
          }

          /// <summary>
          /// Logs every shard in. In automatic mode, first asks the gateway for the recommended
          /// shard count and creates that many shards.
          /// </summary>
          internal override async Task OnLoginAsync(TokenType tokenType, string token)
          {
              if (_automaticShards)
              {
                  var response = await ApiClient.GetBotGatewayAsync().ConfigureAwait(false);
                  _shardIds = Enumerable.Range(0, response.Shards).ToArray();
                  _totalShards = _shardIds.Length;
                  CreateShards();
              }

              //Assume threadsafe: already in a connection lock
              for (int i = 0; i < _shards.Length; i++)
                  await _shards[i].LoginAsync(tokenType, token, false).ConfigureAwait(false);
          }

          /// <summary>
          /// Logs every shard out and clears <see cref="CurrentUser"/>. In automatic mode the
          /// shards themselves are discarded so the next login can recreate them.
          /// </summary>
          internal override async Task OnLogoutAsync()
          {
              //Assume threadsafe: already in a connection lock
              if (_shards != null)
              {
                  for (int i = 0; i < _shards.Length; i++)
                      await _shards[i].LogoutAsync().ConfigureAwait(false);
              }

              CurrentUser = null;
              if (_automaticShards)
              {
                  _shardIds = new int[0];
                  _shardIdsToIndex.Clear();
                  _totalShards = 0;
                  _shards = null;
              }
          }

          /// <inheritdoc />
          public async Task StartAsync()
          {
              await Task.WhenAll(_shards.Select(x => x.StartAsync())).ConfigureAwait(false);
          }

          /// <inheritdoc />
          public async Task StopAsync()
          {
              await Task.WhenAll(_shards.Select(x => x.StopAsync())).ConfigureAwait(false);
          }

          /// <summary> Gets the hosted shard with the given shard id. </summary>
          /// <param name="id">The gateway shard id (not the index into <see cref="Shards"/>).</param>
          /// <returns>The shard client, or null if this client does not host that shard.</returns>
          public DiscordSocketClient GetShard(int id)
          {
              int index;
              if (_shardIdsToIndex.TryGetValue(id, out index))
                  return _shards[index];
              return null;
          }

          // Discord's sharding formula: (guild_id >> 22) % total_shards.
          private int GetShardIdFor(ulong guildId)
              => (int)((guildId >> 22) % (uint)_totalShards);

          /// <summary> Computes the id of the shard responsible for <paramref name="guild"/>. </summary>
          public int GetShardIdFor(IGuild guild)
              => GetShardIdFor(guild.Id);

          private DiscordSocketClient GetShardFor(ulong guildId)
              => GetShard(GetShardIdFor(guildId));

          /// <summary>
          /// Gets the hosted shard responsible for <paramref name="guild"/>, or null when that
          /// shard is not hosted by this client.
          /// </summary>
          public DiscordSocketClient GetShardFor(IGuild guild)
              => GetShardFor(guild.Id);

          /// <inheritdoc />
          public async Task<RestApplication> GetApplicationInfoAsync()
              => await _shards[0].GetApplicationInfoAsync().ConfigureAwait(false);

          /// <inheritdoc />
          /// <remarks>
          /// Returns null when the guild belongs to a shard that this client does not host.
          /// </remarks>
          public SocketGuild GetGuild(ulong id) => GetShardFor(id)?.GetGuild(id);

          /// <inheritdoc />
          public Task<RestGuild> CreateGuildAsync(string name, IVoiceRegion region, Stream jpegIcon = null)
              => ClientHelper.CreateGuildAsync(this, name, region, jpegIcon, new RequestOptions());

          /// <inheritdoc />
          /// <remarks> Searches the shards in order and returns the first match, or null. </remarks>
          public SocketChannel GetChannel(ulong id)
          {
              for (int i = 0; i < _shards.Length; i++)
              {
                  var channel = _shards[i].GetChannel(id);
                  if (channel != null)
                      return channel;
              }
              return null;
          }

          // Lazily concatenates the private channels of every shard.
          private IEnumerable<ISocketPrivateChannel> GetPrivateChannels()
          {
              for (int i = 0; i < _shards.Length; i++)
              {
                  foreach (var channel in _shards[i].PrivateChannels)
                      yield return channel;
              }
          }

          // Sums the private channel counts of every shard.
          private int GetPrivateChannelCount()
          {
              int result = 0;
              for (int i = 0; i < _shards.Length; i++)
                  result += _shards[i].PrivateChannels.Count;
              return result;
          }

          /// <inheritdoc />
          public Task<IReadOnlyCollection<RestConnection>> GetConnectionsAsync()
              => ClientHelper.GetConnectionsAsync(this, new RequestOptions());

          // Lazily concatenates the guilds of every shard.
          private IEnumerable<SocketGuild> GetGuilds()
          {
              for (int i = 0; i < _shards.Length; i++)
              {
                  foreach (var guild in _shards[i].Guilds)
                      yield return guild;
              }
          }

          // Sums the guild counts of every shard.
          private int GetGuildCount()
          {
              int result = 0;
              for (int i = 0; i < _shards.Length; i++)
                  result += _shards[i].Guilds.Count;
              return result;
          }

          /// <inheritdoc />
          public Task<RestInvite> GetInviteAsync(string inviteId)
              => ClientHelper.GetInviteAsync(this, inviteId, new RequestOptions());

          /// <inheritdoc />
          /// <remarks> Searches the shards in order and returns the first match, or null. </remarks>
          public SocketUser GetUser(ulong id)
          {
              for (int i = 0; i < _shards.Length; i++)
              {
                  var user = _shards[i].GetUser(id);
                  if (user != null)
                      return user;
              }
              return null;
          }

          /// <inheritdoc />
          /// <remarks> Searches the shards in order and returns the first match, or null. </remarks>
          public SocketUser GetUser(string username, string discriminator)
          {
              for (int i = 0; i < _shards.Length; i++)
              {
                  var user = _shards[i].GetUser(username, discriminator);
                  if (user != null)
                      return user;
              }
              return null;
          }

          /// <inheritdoc />
          public RestVoiceRegion GetVoiceRegion(string id)
              => _shards[0].GetVoiceRegion(id);

          /// <summary> Downloads the users list for the provided guilds, if they don't have a complete list. </summary>
          /// <remarks>
          /// Each guild is routed to the shard responsible for it. Guilds whose shard is not hosted
          /// by this client are skipped.
          /// </remarks>
          /// <exception cref="ArgumentNullException"><paramref name="guilds"/> is null.</exception>
          public async Task DownloadUsersAsync(IEnumerable<SocketGuild> guilds)
          {
              if (guilds == null)
                  throw new ArgumentNullException(nameof(guilds));

              // Materialize once so a lazily-evaluated sequence is not re-enumerated for every shard.
              var guildList = guilds.ToArray();
              for (int i = 0; i < _shards.Length; i++)
              {
                  int id = _shardIds[i];
                  var arr = guildList.Where(x => GetShardIdFor(x) == id).ToArray();
                  if (arr.Length > 0)
                      await _shards[i].DownloadUsersAsync(arr).ConfigureAwait(false);
              }
          }

          // Average shard latency rounded to the nearest integer; 0 when no shards exist yet
          // (e.g. automatic sharding before login), avoiding a null dereference or NaN.
          private int GetLatency()
          {
              var shards = _shards;
              if (shards == null || shards.Length == 0)
                  return 0;

              int total = 0;
              for (int i = 0; i < shards.Length; i++)
                  total += shards[i].Latency;
              return (int)Math.Round(total / (double)shards.Length);
          }

          /// <summary> Sets the status on every shard, one shard at a time. </summary>
          public async Task SetStatusAsync(UserStatus status)
          {
              for (int i = 0; i < _shards.Length; i++)
                  await _shards[i].SetStatusAsync(status).ConfigureAwait(false);
          }

          /// <summary> Sets the game on every shard, one shard at a time. </summary>
          public async Task SetGameAsync(string name, string streamUrl = null, StreamType streamType = StreamType.NotStreaming)
          {
              for (int i = 0; i < _shards.Length; i++)
                  await _shards[i].SetGameAsync(name, streamUrl, streamType).ConfigureAwait(false);
          }

          /// <summary>
          /// Forwards the events of <paramref name="client"/> to the matching events of this client.
          /// </summary>
          /// <param name="client">The shard whose events are forwarded.</param>
          /// <param name="isPrimary">
          /// True for the first shard, whose <c>Ready</c> event also sets <see cref="CurrentUser"/>.
          /// </param>
          private void RegisterEvents(DiscordSocketClient client, bool isPrimary)
          {
              client.Log += (msg) => _logEvent.InvokeAsync(msg);
              client.LoggedOut += () =>
              {
                  var state = LoginState;
                  if (state == LoginState.LoggedIn || state == LoginState.LoggingIn)
                  {
                      //Should only happen if token is changed
                      var _ = LogoutAsync(); //Signal the logout, fire and forget
                  }
                  return Task.Delay(0);
              };
              if (isPrimary)
              {
                  client.Ready += () =>
                  {
                      CurrentUser = client.CurrentUser;
                      return Task.Delay(0);
                  };
              }

              client.ChannelCreated += (channel) => _channelCreatedEvent.InvokeAsync(channel);
              client.ChannelDestroyed += (channel) => _channelDestroyedEvent.InvokeAsync(channel);
              client.ChannelUpdated += (oldChannel, newChannel) => _channelUpdatedEvent.InvokeAsync(oldChannel, newChannel);

              client.MessageReceived += (msg) => _messageReceivedEvent.InvokeAsync(msg);
              client.MessageDeleted += (cache, channel) => _messageDeletedEvent.InvokeAsync(cache, channel);
              client.MessageUpdated += (oldMsg, newMsg, channel) => _messageUpdatedEvent.InvokeAsync(oldMsg, newMsg, channel);
              client.ReactionAdded += (cache, channel, reaction) => _reactionAddedEvent.InvokeAsync(cache, channel, reaction);
              client.ReactionRemoved += (cache, channel, reaction) => _reactionRemovedEvent.InvokeAsync(cache, channel, reaction);
              client.ReactionsCleared += (cache, channel) => _reactionsClearedEvent.InvokeAsync(cache, channel);

              client.RoleCreated += (role) => _roleCreatedEvent.InvokeAsync(role);
              client.RoleDeleted += (role) => _roleDeletedEvent.InvokeAsync(role);
              client.RoleUpdated += (oldRole, newRole) => _roleUpdatedEvent.InvokeAsync(oldRole, newRole);

              client.JoinedGuild += (guild) => _joinedGuildEvent.InvokeAsync(guild);
              client.LeftGuild += (guild) => _leftGuildEvent.InvokeAsync(guild);
              client.GuildAvailable += (guild) => _guildAvailableEvent.InvokeAsync(guild);
              client.GuildUnavailable += (guild) => _guildUnavailableEvent.InvokeAsync(guild);
              client.GuildMembersDownloaded += (guild) => _guildMembersDownloadedEvent.InvokeAsync(guild);
              client.GuildUpdated += (oldGuild, newGuild) => _guildUpdatedEvent.InvokeAsync(oldGuild, newGuild);

              client.UserJoined += (user) => _userJoinedEvent.InvokeAsync(user);
              client.UserLeft += (user) => _userLeftEvent.InvokeAsync(user);
              client.UserBanned += (user, guild) => _userBannedEvent.InvokeAsync(user, guild);
              client.UserUnbanned += (user, guild) => _userUnbannedEvent.InvokeAsync(user, guild);
              client.UserUpdated += (oldUser, newUser) => _userUpdatedEvent.InvokeAsync(oldUser, newUser);
              client.GuildMemberUpdated += (oldUser, newUser) => _guildMemberUpdatedEvent.InvokeAsync(oldUser, newUser);
              client.UserVoiceStateUpdated += (user, oldVoiceState, newVoiceState) => _userVoiceStateUpdatedEvent.InvokeAsync(user, oldVoiceState, newVoiceState);
              client.CurrentUserUpdated += (oldUser, newUser) => _selfUpdatedEvent.InvokeAsync(oldUser, newUser);
              client.UserIsTyping += (user, channel) => _userIsTypingEvent.InvokeAsync(user, channel);
              client.RecipientAdded += (user) => _recipientAddedEvent.InvokeAsync(user);
              client.RecipientRemoved += (user) => _recipientRemovedEvent.InvokeAsync(user);
          }

          //IDiscordClient
          // Note: the explicit implementations below ignore the CacheMode and RequestOptions
          // arguments; they delegate to the public members, which use the local shard caches
          // or default RequestOptions.
          async Task<IApplication> IDiscordClient.GetApplicationInfoAsync(RequestOptions options)
              => await GetApplicationInfoAsync().ConfigureAwait(false);

          Task<IChannel> IDiscordClient.GetChannelAsync(ulong id, CacheMode mode, RequestOptions options)
              => Task.FromResult<IChannel>(GetChannel(id));
          Task<IReadOnlyCollection<IPrivateChannel>> IDiscordClient.GetPrivateChannelsAsync(CacheMode mode, RequestOptions options)
              => Task.FromResult<IReadOnlyCollection<IPrivateChannel>>(PrivateChannels);

          async Task<IReadOnlyCollection<IConnection>> IDiscordClient.GetConnectionsAsync(RequestOptions options)
              => await GetConnectionsAsync().ConfigureAwait(false);

          async Task<IInvite> IDiscordClient.GetInviteAsync(string inviteId, RequestOptions options)
              => await GetInviteAsync(inviteId).ConfigureAwait(false);

          Task<IGuild> IDiscordClient.GetGuildAsync(ulong id, CacheMode mode, RequestOptions options)
              => Task.FromResult<IGuild>(GetGuild(id));
          Task<IReadOnlyCollection<IGuild>> IDiscordClient.GetGuildsAsync(CacheMode mode, RequestOptions options)
              => Task.FromResult<IReadOnlyCollection<IGuild>>(Guilds);
          async Task<IGuild> IDiscordClient.CreateGuildAsync(string name, IVoiceRegion region, Stream jpegIcon, RequestOptions options)
              => await CreateGuildAsync(name, region, jpegIcon).ConfigureAwait(false);

          Task<IUser> IDiscordClient.GetUserAsync(ulong id, CacheMode mode, RequestOptions options)
              => Task.FromResult<IUser>(GetUser(id));
          Task<IUser> IDiscordClient.GetUserAsync(string username, string discriminator, RequestOptions options)
              => Task.FromResult<IUser>(GetUser(username, discriminator));

          Task<IReadOnlyCollection<IVoiceRegion>> IDiscordClient.GetVoiceRegionsAsync(RequestOptions options)
              => Task.FromResult<IReadOnlyCollection<IVoiceRegion>>(VoiceRegions);
          Task<IVoiceRegion> IDiscordClient.GetVoiceRegionAsync(string id, RequestOptions options)
              => Task.FromResult<IVoiceRegion>(GetVoiceRegion(id));
      }
  }