diff --git a/src/Discord.Net.Core/Cache/CacheableEntityExtensions.cs b/src/Discord.Net.Core/Cache/CacheableEntityExtensions.cs new file mode 100644 index 000000000..fb265f94a --- /dev/null +++ b/src/Discord.Net.Core/Cache/CacheableEntityExtensions.cs @@ -0,0 +1,107 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord +{ + internal static class CacheableEntityExtensions + { + public static IActivityModel ToModel(this RichGame richGame) where TModel : WritableActivityModel, new() + { + return new TModel() + { + ApplicationId = richGame.ApplicationId, + SmallImage = richGame.SmallAsset?.ImageId, + SmallText = richGame.SmallAsset?.Text, + LargeImage = richGame.LargeAsset?.ImageId, + LargeText = richGame.LargeAsset?.Text, + Details = richGame.Details, + Flags = richGame.Flags, + Name = richGame.Name, + Type = richGame.Type, + JoinSecret = richGame.Secrets?.Join, + SpectateSecret = richGame.Secrets?.Spectate, + MatchSecret = richGame.Secrets?.Match, + State = richGame.State, + PartyId = richGame.Party?.Id, + PartySize = richGame.Party?.Members != null && richGame.Party?.Capacity != null + ? new long[] { richGame.Party.Members, richGame.Party.Capacity } + : null, + TimestampEnd = richGame.Timestamps?.End, + TimestampStart = richGame.Timestamps?.Start + }; + } + + public static IActivityModel ToModel(this SpotifyGame spotify) where TModel : WritableActivityModel, new() + { + return new TModel() + { + Name = spotify.Name, + SessionId = spotify.SessionId, + SyncId = spotify.TrackId, + LargeText = spotify.AlbumTitle, + Details = spotify.TrackTitle, + State = string.Join(";", spotify.Artists), + TimestampEnd = spotify.EndsAt, + TimestampStart = spotify.StartedAt, + LargeImage = spotify.AlbumArt, + Type = ActivityType.Listening, + Flags = spotify.Flags, + }; + } + + public static IActivityModel ToModel(this CustomStatusGame custom) + where TModel : WritableActivityModel, new() + where TEmoteModel : WritableEmojiModel, new() + { + return new TModel + { + Type = ActivityType.CustomStatus, + Name = custom.Name, + State = custom.State, + Emoji = custom.Emote.ToModel(), + CreatedAt = custom.CreatedAt + }; + } + + public static IActivityModel ToModel(this StreamingGame stream) where TModel : WritableActivityModel, new() + { + return new TModel + { + Name = stream.Name, + Url = stream.Url, + Flags = stream.Flags, + Details = stream.Details + }; + } + + public static IEmojiModel ToModel(this IEmote emote) where TModel : WritableEmojiModel, new() + { + var model = new TModel() + { + Name = emote.Name + }; + + if(emote is GuildEmote guildEmote) + { + model.Id = guildEmote.Id; + model.IsAnimated = guildEmote.Animated; + model.IsAvailable = guildEmote.IsAvailable; + model.IsManaged = guildEmote.IsManaged; + model.CreatorId = guildEmote.CreatorId; + model.RequireColons = guildEmote.RequireColons; + model.Roles = guildEmote.RoleIds.ToArray(); + } + + if(emote is Emote e) + { + model.IsAnimated = e.Animated; + model.Id = e.Id; + } + + return model; + } + } +} diff --git a/src/Discord.Net.Core/Cache/ICached.cs b/src/Discord.Net.Core/Cache/ICached.cs new file mode 100644 index 000000000..3146741bb --- /dev/null +++ b/src/Discord.Net.Core/Cache/ICached.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord +{ + internal interface ICached + { + TType ToModel(); + } +} diff --git a/src/Discord.Net.Core/Cache/Models/Emoji/IEmojiModel.cs b/src/Discord.Net.Core/Cache/Models/Emoji/IEmojiModel.cs new file mode 100644 index 000000000..bc5b43e2a --- /dev/null +++ b/src/Discord.Net.Core/Cache/Models/Emoji/IEmojiModel.cs @@ -0,0 +1,34 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord +{ + public interface IEmojiModel + { + ulong? Id { get; } + string Name { get; } + ulong[] Roles { get; } + bool RequireColons { get; } + bool IsManaged { get; } + bool IsAnimated { get; } + bool IsAvailable { get; } + + ulong? CreatorId { get; } + } + + internal class WritableEmojiModel : IEmojiModel + { + public ulong? Id { get; set; } + public string Name { get; set; } + public ulong[] Roles { get; set; } + public bool RequireColons { get; set; } + public bool IsManaged { get; set; } + public bool IsAnimated { get; set; } + public bool IsAvailable { get; set; } + + public ulong? CreatorId { get; set; } + } +} diff --git a/src/Discord.Net.Core/Cache/Models/Presense/IActivityModel.cs b/src/Discord.Net.Core/Cache/Models/Presense/IActivityModel.cs new file mode 100644 index 000000000..a66e88754 --- /dev/null +++ b/src/Discord.Net.Core/Cache/Models/Presense/IActivityModel.cs @@ -0,0 +1,88 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord +{ + public interface IActivityModel + { + string Id { get; } + string Url { get; } + string Name { get; } + ActivityType Type { get; } + string Details { get; } + string State { get; } + ActivityProperties Flags { get; } + DateTimeOffset CreatedAt { get; } + IEmojiModel Emoji { get; } + ulong? ApplicationId { get; } + string SyncId { get; } + string SessionId { get; } + + + #region Assets + string LargeImage { get; } + string LargeText { get; } + string SmallImage { get; } + string SmallText { get; } + #endregion + + #region Party + string PartyId { get; } + long[] PartySize { get; } + #endregion + + #region Secrets + string JoinSecret { get; } + string SpectateSecret { get; } + string MatchSecret { get; } + #endregion + + #region Timestamps + DateTimeOffset? TimestampStart { get; } + DateTimeOffset? TimestampEnd { get; } + #endregion + } + + internal class WritableActivityModel : IActivityModel + { + public string Id { get; set; } + public string Url { get; set; } + public string Name { get; set; } + public ActivityType Type { get; set; } + public string Details { get; set; } + public string State { get; set; } + public ActivityProperties Flags { get; set; } + public DateTimeOffset CreatedAt { get; set; } + public IEmojiModel Emoji { get; set; } + public ulong? ApplicationId { get; set; } + public string SyncId { get; set; } + public string SessionId { get; set; } + + + #region Assets + public string LargeImage { get; set; } + public string LargeText { get; set; } + public string SmallImage { get; set; } + public string SmallText { get; set; } + #endregion + + #region Party + public string PartyId { get; set; } + public long[] PartySize { get; set; } + #endregion + + #region Secrets + public string JoinSecret { get; set; } + public string SpectateSecret { get; set; } + public string MatchSecret { get; set; } + #endregion + + #region Timestamps + public DateTimeOffset? TimestampStart { get; set; } + public DateTimeOffset? TimestampEnd { get; set; } + #endregion + } +} diff --git a/src/Discord.Net.Core/Cache/Models/Presense/IPresenceModel.cs b/src/Discord.Net.Core/Cache/Models/Presense/IPresenceModel.cs new file mode 100644 index 000000000..c58a1fad5 --- /dev/null +++ b/src/Discord.Net.Core/Cache/Models/Presense/IPresenceModel.cs @@ -0,0 +1,17 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord +{ + public interface IPresenceModel + { + ulong UserId { get; } + ulong? GuildId { get; } + UserStatus Status { get; } + ClientType[] ActiveClients { get; } + IActivityModel[] Activities { get; } + } +} diff --git a/src/Discord.Net.Core/Cache/Models/Users/ICurrentUserModel.cs b/src/Discord.Net.Core/Cache/Models/Users/ICurrentUserModel.cs new file mode 100644 index 000000000..80b832cf8 --- /dev/null +++ b/src/Discord.Net.Core/Cache/Models/Users/ICurrentUserModel.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord +{ + public interface ICurrentUserModel : IUserModel + { + bool? IsVerified { get; } + string Email { get; } + bool? IsMfaEnabled { get; } + UserProperties Flags { get; } + PremiumType PremiumType { get; } + string Locale { get; } + UserProperties PublicFlags { get; } + } +} diff --git a/src/Discord.Net.Core/Cache/Models/Users/IMemberModel.cs b/src/Discord.Net.Core/Cache/Models/Users/IMemberModel.cs new file mode 100644 index 000000000..3c3a67d5e --- /dev/null +++ b/src/Discord.Net.Core/Cache/Models/Users/IMemberModel.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord +{ + public interface IMemberModel + { + IUserModel User { get; } + + string Nickname { get; } + string GuildAvatar { get; } + ulong[] Roles { get; } + DateTimeOffset JoinedAt { get; } + DateTimeOffset? PremiumSince { get; } + bool IsDeaf { get; } + bool IsMute { get; } + bool? IsPending { get; } + DateTimeOffset? CommunicationsDisabledUntil { get; } + } +} diff --git a/src/Discord.Net.Core/Cache/Models/Users/IUserModel.cs b/src/Discord.Net.Core/Cache/Models/Users/IUserModel.cs new file mode 100644 index 000000000..24d05d0fc --- /dev/null +++ b/src/Discord.Net.Core/Cache/Models/Users/IUserModel.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord +{ + public interface IUserModel : IEntity + { + string Username { get; } + string Discriminator { get; } + bool? IsBot { get; } + string Avatar { get; } + } +} diff --git a/src/Discord.Net.Core/Entities/Activities/SpotifyGame.cs b/src/Discord.Net.Core/Entities/Activities/SpotifyGame.cs index 4eab34fa2..abd89808f 100644 --- a/src/Discord.Net.Core/Entities/Activities/SpotifyGame.cs +++ b/src/Discord.Net.Core/Entities/Activities/SpotifyGame.cs @@ -107,6 +107,8 @@ namespace Discord /// public string TrackUrl { get; internal set; } + internal string AlbumArt { get; set; } + internal SpotifyGame() { } /// diff --git a/src/Discord.Net.Core/Entities/Emotes/GuildEmote.cs b/src/Discord.Net.Core/Entities/Emotes/GuildEmote.cs index 4bd0845c8..6fbe1e80c 100644 --- a/src/Discord.Net.Core/Entities/Emotes/GuildEmote.cs +++ b/src/Discord.Net.Core/Entities/Emotes/GuildEmote.cs @@ -24,6 +24,13 @@ namespace Discord /// public bool RequireColons { get; } /// + /// Gets whether or not the emote is available. + /// + /// + /// An emote can be unavailable if the guild has lost its boost status. + /// + public bool IsAvailable { get; } + /// /// Gets the roles that are allowed to use this emoji. /// /// @@ -39,12 +46,13 @@ namespace Discord /// public ulong? CreatorId { get; } - internal GuildEmote(ulong id, string name, bool animated, bool isManaged, bool requireColons, IReadOnlyList roleIds, ulong? userId) : base(id, name, animated) + internal GuildEmote(ulong id, string name, bool animated, bool isManaged, bool isAvailable, bool requireColons, IReadOnlyList roleIds, ulong? userId) : base(id, name, animated) { IsManaged = isManaged; RequireColons = requireColons; RoleIds = roleIds; CreatorId = userId; + IsAvailable = isAvailable; } private string DebuggerDisplay => $"{Name} ({Id})"; diff --git a/src/Discord.Net.Rest/API/Common/CurrentUser.cs b/src/Discord.Net.Rest/API/Common/CurrentUser.cs new file mode 100644 index 000000000..4a19056b2 --- /dev/null +++ b/src/Discord.Net.Rest/API/Common/CurrentUser.cs @@ -0,0 +1,42 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord.API +{ + internal class CurrentUser : User, ICurrentUserModel + { + [JsonProperty("verified")] + public Optional Verified { get; set; } + [JsonProperty("email")] + public Optional Email { get; set; } + [JsonProperty("mfa_enabled")] + public Optional MfaEnabled { get; set; } + [JsonProperty("flags")] + public Optional Flags { get; set; } + [JsonProperty("premium_type")] + public Optional PremiumType { get; set; } + [JsonProperty("locale")] + public Optional Locale { get; set; } + [JsonProperty("public_flags")] + public Optional PublicFlags { get; set; } + + // ICurrentUserModel + bool? ICurrentUserModel.IsVerified => Verified.ToNullable(); + + string ICurrentUserModel.Email => Email.GetValueOrDefault(); + + bool? ICurrentUserModel.IsMfaEnabled => MfaEnabled.ToNullable(); + + UserProperties ICurrentUserModel.Flags => Flags.GetValueOrDefault(); + + PremiumType ICurrentUserModel.PremiumType => PremiumType.GetValueOrDefault(); + + string ICurrentUserModel.Locale => Locale.GetValueOrDefault(); + + UserProperties ICurrentUserModel.PublicFlags => PublicFlags.GetValueOrDefault(); + } +} diff --git a/src/Discord.Net.Rest/API/Common/Emoji.cs b/src/Discord.Net.Rest/API/Common/Emoji.cs index ff0baa73e..af1a25532 100644 --- a/src/Discord.Net.Rest/API/Common/Emoji.cs +++ b/src/Discord.Net.Rest/API/Common/Emoji.cs @@ -2,7 +2,7 @@ using Newtonsoft.Json; namespace Discord.API { - internal class Emoji + internal class Emoji : IEmojiModel { [JsonProperty("id")] public ulong? Id { get; set; } @@ -16,7 +16,25 @@ namespace Discord.API public bool RequireColons { get; set; } [JsonProperty("managed")] public bool Managed { get; set; } + [JsonProperty("available")] + public Optional Available { get; set; } [JsonProperty("user")] public Optional User { get; set; } + + ulong? IEmojiModel.Id => Id; + + string IEmojiModel.Name => Name; + + ulong[] IEmojiModel.Roles => Roles; + + bool IEmojiModel.RequireColons => RequireColons; + + bool IEmojiModel.IsManaged => Managed; + + bool IEmojiModel.IsAnimated => Animated.GetValueOrDefault(); + + bool IEmojiModel.IsAvailable => Available.GetValueOrDefault(); + + ulong? IEmojiModel.CreatorId => User.GetValueOrDefault()?.Id; } } diff --git a/src/Discord.Net.Rest/API/Common/Game.cs b/src/Discord.Net.Rest/API/Common/Game.cs index 105ce0d73..32e0d5c51 100644 --- a/src/Discord.Net.Rest/API/Common/Game.cs +++ b/src/Discord.Net.Rest/API/Common/Game.cs @@ -1,10 +1,11 @@ using Newtonsoft.Json; using Newtonsoft.Json.Serialization; +using System; using System.Runtime.Serialization; namespace Discord.API { - internal class Game + internal class Game : IActivityModel { [JsonProperty("name")] public string Name { get; set; } @@ -32,7 +33,7 @@ namespace Discord.API public Optional SyncId { get; set; } [JsonProperty("session_id")] public Optional SessionId { get; set; } - [JsonProperty("Flags")] + [JsonProperty("flags")] public Optional Flags { get; set; } [JsonProperty("id")] public Optional Id { get; set; } @@ -40,6 +41,54 @@ namespace Discord.API public Optional Emoji { get; set; } [JsonProperty("created_at")] public Optional CreatedAt { get; set; } + + string IActivityModel.Id => Id.GetValueOrDefault(); + + string IActivityModel.Url => StreamUrl.GetValueOrDefault(); + + string IActivityModel.State => State.GetValueOrDefault(); + + IEmojiModel IActivityModel.Emoji => Emoji.GetValueOrDefault(); + + string IActivityModel.Name => Name; + + ActivityType IActivityModel.Type => Type.GetValueOrDefault().GetValueOrDefault(); + + ActivityProperties IActivityModel.Flags => Flags.GetValueOrDefault(); + + string IActivityModel.Details => Details.GetValueOrDefault(); + DateTimeOffset IActivityModel.CreatedAt => DateTimeOffset.FromUnixTimeMilliseconds(CreatedAt.GetValueOrDefault()); + + ulong? IActivityModel.ApplicationId => ApplicationId.ToNullable(); + + string IActivityModel.SyncId => SyncId.GetValueOrDefault(); + + string IActivityModel.SessionId => SessionId.GetValueOrDefault(); + + string IActivityModel.LargeImage => Assets.GetValueOrDefault()?.LargeImage.GetValueOrDefault(); + + string IActivityModel.LargeText => Assets.GetValueOrDefault()?.LargeText.GetValueOrDefault(); + + string IActivityModel.SmallImage => Assets.GetValueOrDefault()?.SmallImage.GetValueOrDefault(); + + string IActivityModel.SmallText => Assets.GetValueOrDefault()?.SmallText.GetValueOrDefault(); + + string IActivityModel.PartyId => Party.GetValueOrDefault()?.Id; + + long[] IActivityModel.PartySize => Party.GetValueOrDefault()?.Size; + + string IActivityModel.JoinSecret => Secrets.GetValueOrDefault()?.Join; + + string IActivityModel.SpectateSecret => Secrets.GetValueOrDefault()?.Spectate; + + string IActivityModel.MatchSecret => Secrets.GetValueOrDefault()?.Match; + + DateTimeOffset? IActivityModel.TimestampStart => Timestamps.GetValueOrDefault()?.Start.ToNullable(); + + DateTimeOffset? IActivityModel.TimestampEnd => Timestamps.GetValueOrDefault()?.End.ToNullable(); + + + //[JsonProperty("buttons")] //public Optional Buttons { get; set; } diff --git a/src/Discord.Net.Rest/API/Common/GuildMember.cs b/src/Discord.Net.Rest/API/Common/GuildMember.cs index cd3101224..cfe4e652e 100644 --- a/src/Discord.Net.Rest/API/Common/GuildMember.cs +++ b/src/Discord.Net.Rest/API/Common/GuildMember.cs @@ -3,7 +3,7 @@ using System; namespace Discord.API { - internal class GuildMember + internal class GuildMember : IMemberModel { [JsonProperty("user")] public User User { get; set; } @@ -25,5 +25,26 @@ namespace Discord.API public Optional PremiumSince { get; set; } [JsonProperty("communication_disabled_until")] public Optional TimedOutUntil { get; set; } + + // IMemberModel + string IMemberModel.Nickname => Nick.GetValueOrDefault(); + + string IMemberModel.GuildAvatar => Avatar.GetValueOrDefault(); + + ulong[] IMemberModel.Roles => Roles.GetValueOrDefault(Array.Empty()); + + DateTimeOffset IMemberModel.JoinedAt => JoinedAt.GetValueOrDefault(); + + DateTimeOffset? IMemberModel.PremiumSince => PremiumSince.GetValueOrDefault(); + + bool IMemberModel.IsDeaf => Deaf.GetValueOrDefault(false); + + bool IMemberModel.IsMute => Mute.GetValueOrDefault(false); + + bool? IMemberModel.IsPending => Pending.ToNullable(); + + DateTimeOffset? IMemberModel.CommunicationsDisabledUntil => TimedOutUntil.GetValueOrDefault(); + + IUserModel IMemberModel.User => User; } } diff --git a/src/Discord.Net.Rest/API/Common/Presence.cs b/src/Discord.Net.Rest/API/Common/Presence.cs index 23f871ae6..173460242 100644 --- a/src/Discord.Net.Rest/API/Common/Presence.cs +++ b/src/Discord.Net.Rest/API/Common/Presence.cs @@ -1,10 +1,11 @@ using Newtonsoft.Json; using System; using System.Collections.Generic; +using System.Linq; namespace Discord.API { - internal class Presence + internal class Presence : IPresenceModel { [JsonProperty("user")] public User User { get; set; } @@ -28,5 +29,17 @@ namespace Discord.API public List Activities { get; set; } [JsonProperty("premium_since")] public Optional PremiumSince { get; set; } + + ulong IPresenceModel.UserId => User.Id; + + ulong? IPresenceModel.GuildId => GuildId.ToNullable(); + + UserStatus IPresenceModel.Status => Status; + + ClientType[] IPresenceModel.ActiveClients => ClientStatus.IsSpecified + ? ClientStatus.Value.Select(x => (ClientType)Enum.Parse(typeof(ClientType), x.Key, true)).ToArray() + : Array.Empty(); + + IActivityModel[] IPresenceModel.Activities => Activities.ToArray(); } } diff --git a/src/Discord.Net.Rest/API/Common/User.cs b/src/Discord.Net.Rest/API/Common/User.cs index 08fe88cb0..26e5deafd 100644 --- a/src/Discord.Net.Rest/API/Common/User.cs +++ b/src/Discord.Net.Rest/API/Common/User.cs @@ -2,7 +2,7 @@ using Newtonsoft.Json; namespace Discord.API { - internal class User + internal class User : IUserModel { [JsonProperty("id")] public ulong Id { get; set; } @@ -19,20 +19,16 @@ namespace Discord.API [JsonProperty("accent_color")] public Optional AccentColor { get; set; } - //CurrentUser - [JsonProperty("verified")] - public Optional Verified { get; set; } - [JsonProperty("email")] - public Optional Email { get; set; } - [JsonProperty("mfa_enabled")] - public Optional MfaEnabled { get; set; } - [JsonProperty("flags")] - public Optional Flags { get; set; } - [JsonProperty("premium_type")] - public Optional PremiumType { get; set; } - [JsonProperty("locale")] - public Optional Locale { get; set; } - [JsonProperty("public_flags")] - public Optional PublicFlags { get; set; } + + // IUserModel + string IUserModel.Username => Username.GetValueOrDefault(); + + string IUserModel.Discriminator => Discriminator.GetValueOrDefault(); + + bool? IUserModel.IsBot => Bot.ToNullable(); + + string IUserModel.Avatar => Avatar.GetValueOrDefault(); + + ulong IEntity.Id => Id; } } diff --git a/src/Discord.Net.Rest/ClientHelper.cs b/src/Discord.Net.Rest/ClientHelper.cs index c6ad6a9fb..ab0238fee 100644 --- a/src/Discord.Net.Rest/ClientHelper.cs +++ b/src/Discord.Net.Rest/ClientHelper.cs @@ -151,6 +151,16 @@ namespace Discord.Rest return null; } + public static async Task> GetGuildUsersAsync(BaseDiscordClient client, + ulong guildId, RequestOptions options) + { + var guild = await GetGuildAsync(client, guildId, false, options).ConfigureAwait(false); + if (guild == null) + return null; + + return (await GuildHelper.GetUsersAsync(guild, client, null, null, options).FlattenAsync()).ToImmutableArray(); + } + public static async Task GetWebhookAsync(BaseDiscordClient client, ulong id, RequestOptions options) { var model = await client.ApiClient.GetWebhookAsync(id).ConfigureAwait(false); diff --git a/src/Discord.Net.Rest/DiscordRestApiClient.cs b/src/Discord.Net.Rest/DiscordRestApiClient.cs index 3b829ee17..60a95c6e3 100644 --- a/src/Discord.Net.Rest/DiscordRestApiClient.cs +++ b/src/Discord.Net.Rest/DiscordRestApiClient.cs @@ -2063,10 +2063,10 @@ namespace Discord.API #endregion #region Current User/DMs - public async Task GetMyUserAsync(RequestOptions options = null) + public async Task GetMyUserAsync(RequestOptions options = null) { options = RequestOptions.CreateOrClone(options); - return await SendAsync("GET", () => "users/@me", new BucketIds(), options: options).ConfigureAwait(false); + return await SendAsync("GET", () => "users/@me", new BucketIds(), options: options).ConfigureAwait(false); } public async Task> GetMyConnectionsAsync(RequestOptions options = null) { diff --git a/src/Discord.Net.Rest/DiscordRestClient.cs b/src/Discord.Net.Rest/DiscordRestClient.cs index b1948f80a..6ae8bc0b9 100644 --- a/src/Discord.Net.Rest/DiscordRestClient.cs +++ b/src/Discord.Net.Rest/DiscordRestClient.cs @@ -185,6 +185,8 @@ namespace Discord.Rest => ClientHelper.GetUserAsync(this, id, options); public Task GetGuildUserAsync(ulong guildId, ulong id, RequestOptions options = null) => ClientHelper.GetGuildUserAsync(this, guildId, id, options); + public Task> GetGuildUsersAsync(ulong guildId, RequestOptions options = null) + => ClientHelper.GetGuildUsersAsync(this, guildId, options); public Task> GetVoiceRegionsAsync(RequestOptions options = null) => ClientHelper.GetVoiceRegionsAsync(this, options); diff --git a/src/Discord.Net.Rest/Entities/Users/RestGuildUser.cs b/src/Discord.Net.Rest/Entities/Users/RestGuildUser.cs index 0a4a33099..91d716bf4 100644 --- a/src/Discord.Net.Rest/Entities/Users/RestGuildUser.cs +++ b/src/Discord.Net.Rest/Entities/Users/RestGuildUser.cs @@ -94,6 +94,7 @@ namespace Discord.Rest internal void Update(Model model) { base.Update(model.User); + if (model.JoinedAt.IsSpecified) _joinedAtTicks = model.JoinedAt.Value.UtcTicks; if (model.Nick.IsSpecified) diff --git a/src/Discord.Net.Rest/Entities/Users/RestSelfUser.cs b/src/Discord.Net.Rest/Entities/Users/RestSelfUser.cs index b5ef01c53..49ef92a64 100644 --- a/src/Discord.Net.Rest/Entities/Users/RestSelfUser.cs +++ b/src/Discord.Net.Rest/Entities/Users/RestSelfUser.cs @@ -1,7 +1,8 @@ using System; using System.Diagnostics; using System.Threading.Tasks; -using Model = Discord.API.User; +using UserModel = Discord.API.User; +using Model = Discord.API.CurrentUser; namespace Discord.Rest { @@ -28,29 +29,26 @@ namespace Discord.Rest : base(discord, id) { } - internal new static RestSelfUser Create(BaseDiscordClient discord, Model model) + internal new static RestSelfUser Create(BaseDiscordClient discord, UserModel model) { var entity = new RestSelfUser(discord, model.Id); entity.Update(model); return entity; } /// - internal override void Update(Model model) + internal override void Update(UserModel model) { base.Update(model); - if (model.Email.IsSpecified) - Email = model.Email.Value; - if (model.Verified.IsSpecified) - IsVerified = model.Verified.Value; - if (model.MfaEnabled.IsSpecified) - IsMfaEnabled = model.MfaEnabled.Value; - if (model.Flags.IsSpecified) - Flags = (UserProperties)model.Flags.Value; - if (model.PremiumType.IsSpecified) - PremiumType = model.PremiumType.Value; - if (model.Locale.IsSpecified) - Locale = model.Locale.Value; + if (model is not Model currentUserModel) + throw new ArgumentException("Got unexpected model type when updating RestSelfUser"); + + Email = currentUserModel.Email.GetValueOrDefault(); + IsVerified = currentUserModel.Verified.GetValueOrDefault(false); + IsMfaEnabled = currentUserModel.MfaEnabled.GetValueOrDefault(false); + Flags = currentUserModel.Flags.GetValueOrDefault(); + PremiumType = currentUserModel.PremiumType.GetValueOrDefault(); + Locale = currentUserModel.Locale.GetValueOrDefault(); } /// diff --git a/src/Discord.Net.Rest/Entities/Users/RestUser.cs b/src/Discord.Net.Rest/Entities/Users/RestUser.cs index dfdb53815..9074a88e2 100644 --- a/src/Discord.Net.Rest/Entities/Users/RestUser.cs +++ b/src/Discord.Net.Rest/Entities/Users/RestUser.cs @@ -78,20 +78,16 @@ namespace Discord.Rest internal virtual void Update(Model model) { - if (model.Avatar.IsSpecified) - AvatarId = model.Avatar.Value; - if (model.Banner.IsSpecified) - BannerId = model.Banner.Value; - if (model.AccentColor.IsSpecified) - AccentColor = model.AccentColor.Value; - if (model.Discriminator.IsSpecified) + AvatarId = model.Avatar.GetValueOrDefault(); + if(model.Discriminator.IsSpecified) DiscriminatorValue = ushort.Parse(model.Discriminator.Value, NumberStyles.None, CultureInfo.InvariantCulture); - if (model.Bot.IsSpecified) - IsBot = model.Bot.Value; - if (model.Username.IsSpecified) - Username = model.Username.Value; - if (model.PublicFlags.IsSpecified) - PublicFlags = model.PublicFlags.Value; + IsBot = model.Bot.GetValueOrDefault(false); + Username = model.Username.GetValueOrDefault(); + + if(model is ICurrentUserModel currentUserModel) + { + PublicFlags = currentUserModel.PublicFlags; + } } /// diff --git a/src/Discord.Net.Rest/Extensions/EntityExtensions.cs b/src/Discord.Net.Rest/Extensions/EntityExtensions.cs index 4062cda3d..5de5b7b5d 100644 --- a/src/Discord.Net.Rest/Extensions/EntityExtensions.cs +++ b/src/Discord.Net.Rest/Extensions/EntityExtensions.cs @@ -6,6 +6,23 @@ namespace Discord.Rest { internal static class EntityExtensions { + public static IEmote ToIEmote(this IEmojiModel model) + { + if (model.Id.HasValue) + return model.ToEntity(); + return new Emoji(model.Name); + } + + public static GuildEmote ToEntity(this IEmojiModel model) + => new GuildEmote(model.Id.Value, + model.Name, + model.IsAnimated, + model.IsManaged, + model.IsAvailable, + model.RequireColons, + ImmutableArray.Create(model.Roles), + model.CreatorId); + public static IEmote ToIEmote(this API.Emoji model) { if (model.Id.HasValue) @@ -18,6 +35,7 @@ namespace Discord.Rest model.Name, model.Animated.GetValueOrDefault(), model.Managed, + model.Available.GetValueOrDefault(), model.RequireColons, ImmutableArray.Create(model.Roles), model.User.IsSpecified ? model.User.Value.Id : (ulong?)null); @@ -170,48 +188,5 @@ namespace Discord.Rest { return new Overwrite(model.TargetId, model.TargetType, new OverwritePermissions(model.Allow, model.Deny)); } - - public static API.Message ToMessage(this API.InteractionResponse model, IDiscordInteraction interaction) - { - if (model.Data.IsSpecified) - { - var data = model.Data.Value; - var messageModel = new API.Message - { - IsTextToSpeech = data.TTS, - Content = (data.Content.IsSpecified && data.Content.Value == null) ? Optional.Unspecified : data.Content, - Embeds = data.Embeds, - AllowedMentions = data.AllowedMentions, - Components = data.Components, - Flags = data.Flags, - }; - - if(interaction is IApplicationCommandInteraction command) - { - messageModel.Interaction = new API.MessageInteraction - { - Id = command.Id, - Name = command.Data.Name, - Type = InteractionType.ApplicationCommand, - User = new API.User - { - Username = command.User.Username, - Avatar = command.User.AvatarId, - Bot = command.User.IsBot, - Discriminator = command.User.Discriminator, - PublicFlags = command.User.PublicFlags.HasValue ? command.User.PublicFlags.Value : Optional.Unspecified, - Id = command.User.Id, - } - }; - } - - return messageModel; - } - - return new API.Message - { - Id = interaction.Id, - }; - } } } diff --git a/src/Discord.Net.WebSocket/API/Gateway/ReadyEvent.cs b/src/Discord.Net.WebSocket/API/Gateway/ReadyEvent.cs index e0b5fc0b5..10b7adf2e 100644 --- a/src/Discord.Net.WebSocket/API/Gateway/ReadyEvent.cs +++ b/src/Discord.Net.WebSocket/API/Gateway/ReadyEvent.cs @@ -17,7 +17,7 @@ namespace Discord.API.Gateway [JsonProperty("v")] public int Version { get; set; } [JsonProperty("user")] - public User User { get; set; } + public CurrentUser User { get; set; } [JsonProperty("session_id")] public string SessionId { get; set; } [JsonProperty("read_state")] diff --git a/src/Discord.Net.WebSocket/Cache/CacheRunMode.cs b/src/Discord.Net.WebSocket/Cache/CacheRunMode.cs new file mode 100644 index 000000000..f53719e06 --- /dev/null +++ b/src/Discord.Net.WebSocket/Cache/CacheRunMode.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord.WebSocket +{ + public enum CacheRunMode + { + /// + /// The cache should preform a synchronous cache lookup. + /// + Sync, + + /// + /// The cache should preform either a or asynchronous cache lookup. + /// + Async + } +} diff --git a/src/Discord.Net.WebSocket/Cache/DefaultConcurrentCacheProvider.cs b/src/Discord.Net.WebSocket/Cache/DefaultConcurrentCacheProvider.cs new file mode 100644 index 000000000..f0131759d --- /dev/null +++ b/src/Discord.Net.WebSocket/Cache/DefaultConcurrentCacheProvider.cs @@ -0,0 +1,82 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord.WebSocket +{ + public class DefaultConcurrentCacheProvider : ICacheProvider + { + private readonly ConcurrentDictionary _users; + private readonly ConcurrentDictionary> _members; + private readonly ConcurrentDictionary _presense; + + private ValueTask CompletedValueTask => new ValueTask(Task.CompletedTask).Preserve(); + + public DefaultConcurrentCacheProvider(int defaultConcurrency, int defaultCapacity) + { + _users = new(defaultConcurrency, defaultCapacity); + _members = new(defaultConcurrency, defaultCapacity); + _presense = new(defaultConcurrency, defaultCapacity); + } + + public ValueTask AddOrUpdateUserAsync(IUserModel model, CacheRunMode mode) + { + _users.AddOrUpdate(model.Id, model, (_, __) => model); + return CompletedValueTask; + } + public ValueTask AddOrUpdateMemberAsync(IMemberModel model, ulong guildId, CacheRunMode mode) + { + var guildMemberCache = _members.GetOrAdd(guildId, (_) => new ConcurrentDictionary()); + guildMemberCache.AddOrUpdate(model.User.Id, model, (_, __) => model); + return CompletedValueTask; + } + public ValueTask GetMemberAsync(ulong id, ulong guildId, CacheRunMode mode) + => new ValueTask(_members.FirstOrDefault(x => x.Key == guildId).Value?.FirstOrDefault(x => x.Key == id).Value); + + public ValueTask> GetMembersAsync(ulong guildId, CacheRunMode mode) + { + if(_members.TryGetValue(guildId, out var inner)) + return new ValueTask>(inner.ToArray().Select(x => x.Value)); // ToArray here is important before .Select due to concurrency + return new ValueTask>(Array.Empty()); + } + public ValueTask GetUserAsync(ulong id, CacheRunMode mode) + { + if (_users.TryGetValue(id, out var result)) + return new ValueTask(result); + return new ValueTask((IUserModel)null); + } + public ValueTask> GetUsersAsync(CacheRunMode mode) + => new ValueTask>(_users.ToArray().Select(x => x.Value)); + public ValueTask RemoveMemberAsync(ulong id, ulong guildId, CacheRunMode mode) + { + if (_members.TryGetValue(guildId, out var inner)) + inner.TryRemove(id, out var _); + return CompletedValueTask; + } + public ValueTask RemoveUserAsync(ulong id, CacheRunMode mode) + { + _members.TryRemove(id, out var _); + return CompletedValueTask; + } + + public ValueTask GetPresenceAsync(ulong userId, CacheRunMode runmode) + { + if (_presense.TryGetValue(userId, out var presense)) + return new ValueTask(presense); + return new ValueTask((IPresenceModel)null); + } + public ValueTask AddOrUpdatePresenseAsync(ulong userId, IPresenceModel presense, CacheRunMode runmode) + { + _presense.AddOrUpdate(userId, presense, (_, __) => presense); + return CompletedValueTask; + } + public ValueTask RemovePresenseAsync(ulong userId, CacheRunMode runmode) + { + _presense.TryRemove(userId, out var _); + return CompletedValueTask; + } + } +} diff --git a/src/Discord.Net.WebSocket/Cache/ICacheProvider.cs b/src/Discord.Net.WebSocket/Cache/ICacheProvider.cs new file mode 100644 index 000000000..265580193 --- /dev/null +++ b/src/Discord.Net.WebSocket/Cache/ICacheProvider.cs @@ -0,0 +1,37 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord.WebSocket +{ + public interface ICacheProvider + { + #region Users + + ValueTask GetUserAsync(ulong id, CacheRunMode runmode); + ValueTask> GetUsersAsync(CacheRunMode runmode); + ValueTask AddOrUpdateUserAsync(IUserModel model, CacheRunMode runmode); + ValueTask RemoveUserAsync(ulong id, CacheRunMode runmode); + + #endregion + + #region Members + + ValueTask GetMemberAsync(ulong id, ulong guildId, CacheRunMode runmode); + ValueTask> GetMembersAsync(ulong guildId, CacheRunMode runmode); + ValueTask AddOrUpdateMemberAsync(IMemberModel model, ulong guildId, CacheRunMode runmode); + ValueTask RemoveMemberAsync(ulong id, ulong guildId, CacheRunMode runmode); + + #endregion + + #region Presence + + ValueTask GetPresenceAsync(ulong userId, CacheRunMode runmode); + ValueTask AddOrUpdatePresenseAsync(ulong userId, IPresenceModel presense, CacheRunMode runmode); + ValueTask RemovePresenseAsync(ulong userId, CacheRunMode runmode); + + #endregion + } +} diff --git a/src/Discord.Net.WebSocket/ClientStateManager.Experiment.cs b/src/Discord.Net.WebSocket/ClientStateManager.Experiment.cs new file mode 100644 index 000000000..cc59c57aa --- /dev/null +++ b/src/Discord.Net.WebSocket/ClientStateManager.Experiment.cs @@ -0,0 +1,160 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord.WebSocket +{ + internal class CacheWeakReference : WeakReference + { + public new T Target { get => (T)base.Target; set => base.Target = value; } + public CacheWeakReference(T target) + : base(target, false) + { + + } + + public bool TryGetTarget(out T target) + { + target = Target; + return IsAlive; + } + } + + internal partial class ClientStateManager + { + private readonly ConcurrentDictionary> _userReferences = new(); + private readonly ConcurrentDictionary<(ulong GuildId, ulong UserId), CacheWeakReference> _memberReferences = new(); + + + #region Helpers + + private void EnsureSync(ValueTask vt) + { + if (!vt.IsCompleted) + throw new NotSupportedException($"Cannot use async context for value task lookup"); + } + + #endregion + + #region Global users + internal void RemoveReferencedGlobalUser(ulong id) + => _userReferences.TryRemove(id, out _); + + private void TrackGlobalUser(ulong id, SocketGlobalUser user) + { + if (user != null) + { + _userReferences.TryAdd(id, new CacheWeakReference(user)); + } + } + + internal ValueTask GetUserAsync(ulong id, CacheMode mode = CacheMode.AllowDownload, RequestOptions options = null) + => _state.GetUserAsync(id, mode.ToBehavior(), options); + + internal SocketGlobalUser GetUser(ulong id) + { + if (_userReferences.TryGetValue(id, out var userRef) && userRef.TryGetTarget(out var user)) + return user; + + user = (SocketGlobalUser)_state.GetUserAsync(id, StateBehavior.SyncOnly).Result; + + if(user != null) + TrackGlobalUser(id, user); + + return user; + } + + internal SocketGlobalUser GetOrAddUser(ulong id, Func userFactory) + { + if (_userReferences.TryGetValue(id, out var userRef) && userRef.TryGetTarget(out var user)) + return user; + + user = GetUser(id); + + if (user == null) + { + user ??= userFactory(id); + _state.AddOrUpdateUserAsync(user); + TrackGlobalUser(id, user); + } + + return user; + } + + internal void RemoveUser(ulong id) + { + _state.RemoveUserAsync(id); + } + #endregion + + #region GuildUsers + private void TrackMember(ulong userId, ulong guildId, SocketGuildUser user) + { + if(user != null) + { + _memberReferences.TryAdd((guildId, userId), new CacheWeakReference(user)); + } + } + internal void RemovedReferencedMember(ulong userId, ulong guildId) + => _memberReferences.TryRemove((guildId, userId), out _); + + internal ValueTask GetMemberAsync(ulong userId, ulong guildId, CacheMode mode = CacheMode.AllowDownload, RequestOptions options = null) + => _state.GetMemberAsync(guildId, userId, mode.ToBehavior(), options); + + internal SocketGuildUser GetMember(ulong userId, ulong guildId) + { + if (_memberReferences.TryGetValue((guildId, userId), out var memberRef) && memberRef.TryGetTarget(out var member)) + return member; + member = (SocketGuildUser)_state.GetMemberAsync(guildId, userId, StateBehavior.SyncOnly).Result; + if(member != null) + TrackMember(userId, guildId, member); + return member; + } + + internal SocketGuildUser GetOrAddMember(ulong userId, ulong guildId, Func memberFactory) + { + if (_memberReferences.TryGetValue((guildId, userId), out var memberRef) && memberRef.TryGetTarget(out var member)) + return member; + + member = GetMember(userId, guildId); + + if (member == null) + { + member ??= memberFactory(userId, guildId); + TrackMember(userId, guildId, member); + Task.Run(async () => await _state.AddOrUpdateMemberAsync(guildId, member)); // can run async, think of this as fire and forget. + } + + return member; + } + + internal IEnumerable GetMembers(ulong guildId) + => _state.GetMembersAsync(guildId, StateBehavior.SyncOnly).Result; + + internal void AddOrUpdateMember(ulong guildId, SocketGuildUser user) + => EnsureSync(_state.AddOrUpdateMemberAsync(guildId, user)); + + internal void RemoveMember(ulong userId, ulong guildId) + => EnsureSync(_state.RemoveMemberAsync(guildId, userId)); + + #endregion + + #region Presence + internal void AddOrUpdatePresence(SocketPresence presence) + { + EnsureSync(_state.AddOrUpdatePresenseAsync(presence.UserId, presence, StateBehavior.SyncOnly)); + } + + internal SocketPresence GetPresence(ulong userId) + { + if (_state.GetPresenceAsync(userId, StateBehavior.SyncOnly).Result is not SocketPresence socketPresence) + throw new NotSupportedException("Cannot use non-socket entity for presence"); + + return socketPresence; + } + #endregion + } +} diff --git a/src/Discord.Net.WebSocket/ClientState.cs b/src/Discord.Net.WebSocket/ClientStateManager.cs similarity index 91% rename from src/Discord.Net.WebSocket/ClientState.cs rename to src/Discord.Net.WebSocket/ClientStateManager.cs index c40ae3f92..1416e9cf9 100644 --- a/src/Discord.Net.WebSocket/ClientState.cs +++ b/src/Discord.Net.WebSocket/ClientStateManager.cs @@ -5,7 +5,7 @@ using System.Linq; namespace Discord.WebSocket { - internal class ClientState + internal partial class ClientStateManager { private const double AverageChannelsPerGuild = 10.22; //Source: Googie2149 private const double AverageUsersPerGuild = 47.78; //Source: Googie2149 @@ -30,8 +30,11 @@ namespace Discord.WebSocket _groupChannels.Select(x => GetChannel(x) as ISocketPrivateChannel)) .ToReadOnlyCollection(() => _dmChannels.Count + _groupChannels.Count); - public ClientState(int guildCount, int dmChannelCount) + private readonly IStateProvider _state; + + public ClientStateManager(IStateProvider state, int guildCount, int dmChannelCount) { + _state = state; double estimatedChannelCount = guildCount * AverageChannelsPerGuild + dmChannelCount; double estimatedUsersCount = guildCount * AverageUsersPerGuild; _channels = new ConcurrentDictionary(ConcurrentHashSet.DefaultConcurrencyLevel, (int)(estimatedChannelCount * CollectionMultiplier)); @@ -121,22 +124,6 @@ namespace Discord.WebSocket return null; } - internal SocketGlobalUser GetUser(ulong id) - { - if (_users.TryGetValue(id, out SocketGlobalUser user)) - return user; - return null; - } - internal SocketGlobalUser GetOrAddUser(ulong id, Func userFactory) - { - return _users.GetOrAdd(id, userFactory); - } - internal SocketGlobalUser RemoveUser(ulong id) - { - if (_users.TryRemove(id, out SocketGlobalUser user)) - return user; - return null; - } internal void PurgeUsers() { foreach (var guild in _guilds.Values) diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index 3a14692e0..25fd2abb8 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -200,7 +200,7 @@ namespace Discord.WebSocket return _shards[id]; return null; } - private int GetShardIdFor(ulong guildId) + public int GetShardIdFor(ulong guildId) => (int)((guildId >> 22) % (uint)_totalShards); public int GetShardIdFor(IGuild guild) => GetShardIdFor(guild?.Id ?? 0); diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index aaef4656a..355dec006 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -70,9 +70,10 @@ namespace Discord.WebSocket internal int TotalShards { get; private set; } internal int MessageCacheSize { get; private set; } internal int LargeThreshold { get; private set; } - internal ClientState State { get; private set; } + internal ClientStateManager StateManager { get; private set; } internal UdpSocketProvider UdpSocketProvider { get; private set; } internal WebSocketProvider WebSocketProvider { get; private set; } + internal IStateProvider StateProvider { get; private set; } internal bool AlwaysDownloadUsers { get; private set; } internal int? HandlerTimeout { get; private set; } internal bool AlwaysDownloadDefaultStickers { get; private set; } @@ -81,7 +82,7 @@ namespace Discord.WebSocket internal bool SuppressUnknownDispatchWarnings { get; private set; } internal new DiscordSocketApiClient ApiClient => base.ApiClient; /// - public override IReadOnlyCollection Guilds => State.Guilds; + public override IReadOnlyCollection Guilds => StateManager.Guilds; /// public override IReadOnlyCollection> DefaultStickerPacks { @@ -94,7 +95,7 @@ namespace Discord.WebSocket } } /// - public override IReadOnlyCollection PrivateChannels => State.PrivateChannels; + public override IReadOnlyCollection PrivateChannels => StateManager.PrivateChannels; /// /// Gets a collection of direct message channels opened in this session. /// @@ -109,7 +110,7 @@ namespace Discord.WebSocket /// A collection of DM channels that have been opened in this session. /// public IReadOnlyCollection DMChannels - => State.PrivateChannels.OfType().ToImmutableArray(); + => StateManager.PrivateChannels.OfType().ToImmutableArray(); /// /// Gets a collection of group channels opened in this session. /// @@ -124,7 +125,7 @@ namespace Discord.WebSocket /// A collection of group channels that have been opened in this session. /// public IReadOnlyCollection GroupChannels - => State.PrivateChannels.OfType().ToImmutableArray(); + => StateManager.PrivateChannels.OfType().ToImmutableArray(); /// /// Initializes a new REST/WebSocket-based Discord client. @@ -141,6 +142,7 @@ namespace Discord.WebSocket private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, DiscordShardedClient shardedClient, DiscordSocketClient parentClient) : base(config, client) { + // TODO: config concurrency and size ShardId = config.ShardId ?? 0; TotalShards = config.TotalShards ?? 1; MessageCacheSize = config.MessageCacheSize; @@ -153,7 +155,6 @@ namespace Discord.WebSocket LogGatewayIntentWarnings = config.LogGatewayIntentWarnings; SuppressUnknownDispatchWarnings = config.SuppressUnknownDispatchWarnings; HandlerTimeout = config.HandlerTimeout; - State = new ClientState(0, 0); Rest = new DiscordSocketRestClient(config, ApiClient); _heartbeatTimes = new ConcurrentQueue(); _gatewayIntents = config.GatewayIntents; @@ -165,6 +166,7 @@ namespace Discord.WebSocket OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x); _connection.Connected += () => TimedInvokeAsync(_connectedEvent, nameof(Connected)); _connection.Disconnected += (ex, recon) => TimedInvokeAsync(_disconnectedEvent, nameof(Disconnected), ex); + StateProvider = config.StateProvider ?? new DefaultStateProvider(_gatewayLogger, config.CacheProvider ?? new DefaultConcurrentCacheProvider(5, 50), this, config.DefaultStateBehavior); _nextAudioId = 1; _shardedClient = shardedClient; @@ -200,6 +202,17 @@ namespace Discord.WebSocket private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) => new DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayHost, useSystemClock: config.UseSystemClock, defaultRatelimitCallback: config.DefaultRatelimitCallback); + + #region State + + public ValueTask GetUserAsync(ulong id, CacheMode cacheMode = CacheMode.AllowDownload, RequestOptions options = null) + => StateManager.GetUserAsync(id, cacheMode, options); + + public ValueTask GetGuildUserAsync(ulong userId, ulong guildId, CacheMode cacheMode = CacheMode.AllowDownload, RequestOptions options = null) + => StateManager.GetMemberAsync(userId, guildId, cacheMode, options); + + #endregion + /// internal override void Dispose(bool disposing) { @@ -217,7 +230,6 @@ namespace Discord.WebSocket base.Dispose(disposing); } - internal override async ValueTask DisposeAsync(bool disposing) { if (!_isDisposed) @@ -348,7 +360,7 @@ namespace Discord.WebSocket //Raise virtual GUILD_UNAVAILABLEs await _gatewayLogger.DebugAsync("Raising virtual GuildUnavailables").ConfigureAwait(false); - foreach (var guild in State.Guilds) + foreach (var guild in StateManager.Guilds) { if (guild.IsAvailable) await GuildUnavailableAsync(guild).ConfigureAwait(false); @@ -361,11 +373,11 @@ namespace Discord.WebSocket /// public override SocketGuild GetGuild(ulong id) - => State.GetGuild(id); + => StateManager.GetGuild(id); /// public override SocketChannel GetChannel(ulong id) - => State.GetChannel(id); + => StateManager.GetChannel(id); /// /// Gets a generic channel from the cache or does a rest request if unavailable. /// @@ -387,27 +399,9 @@ namespace Discord.WebSocket public async ValueTask GetChannelAsync(ulong id, RequestOptions options = null) => GetChannel(id) ?? (IChannel)await ClientHelper.GetChannelAsync(this, id, options).ConfigureAwait(false); /// - /// Gets a user from the cache or does a rest request if unavailable. - /// - /// - /// - /// var user = await _client.GetUserAsync(168693960628371456); - /// if (user != null) - /// Console.WriteLine($"{user} is created at {user.CreatedAt}."; - /// - /// - /// The snowflake identifier of the user (e.g. `168693960628371456`). - /// The options to be used when sending the request. - /// - /// A task that represents the asynchronous get operation. The task result contains the user associated with - /// the snowflake identifier; null if the user is not found. - /// - public async ValueTask GetUserAsync(ulong id, RequestOptions options = null) - => await ClientHelper.GetUserAsync(this, id, options).ConfigureAwait(false); - /// /// Clears all cached channels from the client. /// - public void PurgeChannelCache() => State.PurgeAllChannels(); + public void PurgeChannelCache() => StateManager.PurgeAllChannels(); /// /// Clears cached DM channels from the client. /// @@ -415,10 +409,10 @@ namespace Discord.WebSocket /// public override SocketUser GetUser(ulong id) - => State.GetUser(id); + => StateManager.GetUser(id); /// public override SocketUser GetUser(string username, string discriminator) - => State.Users.FirstOrDefault(x => x.Discriminator == discriminator && x.Username == username); + => StateManager.Users.FirstOrDefault(x => x.Discriminator == discriminator && x.Username == username); /// /// Gets a global application command. @@ -431,7 +425,7 @@ namespace Discord.WebSocket /// public async ValueTask GetGlobalApplicationCommandAsync(ulong id, RequestOptions options = null) { - var command = State.GetCommand(id); + var command = StateManager.GetCommand(id); if (command != null) return command; @@ -443,7 +437,7 @@ namespace Discord.WebSocket command = SocketApplicationCommand.Create(this, model); - State.AddCommand(command); + StateManager.AddCommand(command); return command; } @@ -461,7 +455,7 @@ namespace Discord.WebSocket foreach(var command in commands) { - State.AddCommand(command); + StateManager.AddCommand(command); } return commands.ToImmutableArray(); @@ -471,7 +465,7 @@ namespace Discord.WebSocket { var model = await InteractionHelper.CreateGlobalCommandAsync(this, properties, options).ConfigureAwait(false); - var entity = State.GetOrAddCommand(model.Id, (id) => SocketApplicationCommand.Create(this, model)); + var entity = StateManager.GetOrAddCommand(model.Id, (id) => SocketApplicationCommand.Create(this, model)); //Update it in case it was cached entity.Update(model); @@ -486,11 +480,11 @@ namespace Discord.WebSocket var entities = models.Select(x => SocketApplicationCommand.Create(this, x)); //Purge our previous commands - State.PurgeCommands(x => x.IsGlobalCommand); + StateManager.PurgeCommands(x => x.IsGlobalCommand); foreach(var entity in entities) { - State.AddCommand(entity); + StateManager.AddCommand(entity); } return entities.ToImmutableArray(); @@ -499,27 +493,26 @@ namespace Discord.WebSocket /// /// Clears cached users from the client. /// - public void PurgeUserCache() => State.PurgeUsers(); - internal SocketGlobalUser GetOrCreateUser(ClientState state, Discord.API.User model) + public void PurgeUserCache() => StateManager.PurgeUsers(); + internal SocketGlobalUser GetOrCreateUser(ClientStateManager state, IUserModel model) { return state.GetOrAddUser(model.Id, x => SocketGlobalUser.Create(this, state, model)); } - internal SocketUser GetOrCreateTemporaryUser(ClientState state, Discord.API.User model) + internal SocketUser GetOrCreateTemporaryUser(ClientStateManager state, Discord.API.User model) { return state.GetUser(model.Id) ?? (SocketUser)SocketUnknownUser.Create(this, state, model); } - internal SocketGlobalUser GetOrCreateSelfUser(ClientState state, Discord.API.User model) + internal SocketGlobalUser GetOrCreateSelfUser(ClientStateManager state, ICurrentUserModel model) { return state.GetOrAddUser(model.Id, x => { var user = SocketGlobalUser.Create(this, state, model); user.GlobalUser.AddRef(); - user.Presence = new SocketPresence(UserStatus.Online, null, null); return user; }); } internal void RemoveUser(ulong id) - => State.RemoveUser(id); + => StateManager.RemoveUser(id); /// public override async Task GetStickerAsync(ulong id, CacheMode mode = CacheMode.AllowDownload, RequestOptions options = null) @@ -548,7 +541,7 @@ namespace Discord.WebSocket if (model.GuildId.IsSpecified) { - var guild = State.GetGuild(model.GuildId.Value); + var guild = StateManager.GetGuild(model.GuildId.Value); //Since the sticker can be from another guild, check if we are in the guild or its in the cache if (guild != null) @@ -696,7 +689,7 @@ namespace Discord.WebSocket if (CurrentUser == null) return; var activities = _activity.IsSpecified ? ImmutableList.Create(_activity.Value) : null; - CurrentUser.Presence = new SocketPresence(Status, null, activities); + StateManager.AddOrUpdatePresence(new SocketPresence(Status, null, activities)); var presence = BuildCurrentStatus() ?? (UserStatus.Online, false, null, null); @@ -866,12 +859,13 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (READY)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var state = new ClientState(data.Guilds.Length, data.PrivateChannels.Length); + var state = new ClientStateManager(StateProvider, data.Guilds.Length, data.PrivateChannels.Length); + StateManager = state; var currentUser = SocketSelfUser.Create(this, state, data.User); Rest.CreateRestSelfUser(data.User); var activities = _activity.IsSpecified ? ImmutableList.Create(_activity.Value) : null; - currentUser.Presence = new SocketPresence(Status, null, activities); + StateManager.AddOrUpdatePresence(new SocketPresence(Status, null, activities)); ApiClient.CurrentUserId = currentUser.Id; ApiClient.CurrentApplicationId = data.Application.Id; Rest.CurrentUser = RestSelfUser.Create(this, data.User); @@ -892,7 +886,6 @@ namespace Discord.WebSocket _unavailableGuildCount = unavailableGuilds; CurrentUser = currentUser; _previousSessionUser = CurrentUser; - State = state; } catch (Exception ex) { @@ -928,7 +921,7 @@ namespace Discord.WebSocket _ = _connection.CompleteAsync(); //Notify the client that these guilds are available again - foreach (var guild in State.Guilds) + foreach (var guild in StateManager.Guilds) { if (guild.IsAvailable) await GuildAvailableAsync(guild).ConfigureAwait(false); @@ -953,10 +946,10 @@ namespace Discord.WebSocket _lastGuildAvailableTime = Environment.TickCount; await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_AVAILABLE)").ConfigureAwait(false); - var guild = State.GetGuild(data.Id); + var guild = StateManager.GetGuild(data.Id); if (guild != null) { - guild.Update(State, data); + guild.Update(StateManager, data); if (_unavailableGuildCount != 0) _unavailableGuildCount--; @@ -978,7 +971,7 @@ namespace Discord.WebSocket { await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_CREATE)").ConfigureAwait(false); - var guild = AddGuild(data, State); + var guild = AddGuild(data, StateManager); if (guild != null) { await TimedInvokeAsync(_joinedGuildEvent, nameof(JoinedGuild), guild).ConfigureAwait(false); @@ -997,11 +990,11 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_UPDATE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.Id); + var guild = StateManager.GetGuild(data.Id); if (guild != null) { var before = guild.Clone(); - guild.Update(State, data); + guild.Update(StateManager, data); await TimedInvokeAsync(_guildUpdatedEvent, nameof(GuildUpdated), before, guild).ConfigureAwait(false); } else @@ -1016,11 +1009,11 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_EMOJIS_UPDATE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild != null) { var before = guild.Clone(); - guild.Update(State, data); + guild.Update(StateManager, data); await TimedInvokeAsync(_guildUpdatedEvent, nameof(GuildUpdated), before, guild).ConfigureAwait(false); } else @@ -1061,7 +1054,7 @@ namespace Discord.WebSocket type = "GUILD_UNAVAILABLE"; await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_UNAVAILABLE)").ConfigureAwait(false); - var guild = State.GetGuild(data.Id); + var guild = StateManager.GetGuild(data.Id); if (guild != null) { await GuildUnavailableAsync(guild).ConfigureAwait(false); @@ -1098,7 +1091,7 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild == null) { @@ -1156,10 +1149,10 @@ namespace Discord.WebSocket SocketChannel channel = null; if (data.GuildId.IsSpecified) { - var guild = State.GetGuild(data.GuildId.Value); + var guild = StateManager.GetGuild(data.GuildId.Value); if (guild != null) { - channel = guild.AddChannel(State, data); + channel = guild.AddChannel(StateManager, data); if (!guild.IsSynced) { @@ -1175,10 +1168,10 @@ namespace Discord.WebSocket } else { - channel = State.GetChannel(data.Id); + channel = StateManager.GetChannel(data.Id); if (channel != null) return; //Discord may send duplicate CHANNEL_CREATEs for DMs - channel = AddPrivateChannel(data, State) as SocketChannel; + channel = AddPrivateChannel(data, StateManager) as SocketChannel; } if (channel != null) @@ -1190,11 +1183,11 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (CHANNEL_UPDATE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var channel = State.GetChannel(data.Id); + var channel = StateManager.GetChannel(data.Id); if (channel != null) { var before = channel.Clone(); - channel.Update(State, data); + channel.Update(StateManager, data); var guild = (channel as SocketGuildChannel)?.Guild; if (!(guild?.IsSynced ?? true)) @@ -1220,10 +1213,10 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); if (data.GuildId.IsSpecified) { - var guild = State.GetGuild(data.GuildId.Value); + var guild = StateManager.GetGuild(data.GuildId.Value); if (guild != null) { - channel = guild.RemoveChannel(State, data.Id); + channel = guild.RemoveChannel(StateManager, data.Id); if (!guild.IsSynced) { @@ -1257,7 +1250,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_MEMBER_ADD)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild != null) { var user = guild.AddOrUpdateUser(data); @@ -1283,7 +1276,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_MEMBER_UPDATE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild != null) { var user = guild.GetUser(data.User.Id); @@ -1297,13 +1290,13 @@ namespace Discord.WebSocket if (user != null) { var before = user.Clone(); - if (user.GlobalUser.Update(State, data.User)) + if (user.GlobalUser.Update(StateManager, data.User)) { //Global data was updated, trigger UserUpdated await TimedInvokeAsync(_userUpdatedEvent, nameof(UserUpdated), before.GlobalUser, user).ConfigureAwait(false); } - user.Update(State, data); + user.Update(StateManager, data); var cacheableBefore = new Cacheable(before, user.Id, true, () => null); await TimedInvokeAsync(_guildMemberUpdatedEvent, nameof(GuildMemberUpdated), cacheableBefore, user).ConfigureAwait(false); @@ -1327,7 +1320,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_MEMBER_REMOVE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild != null) { SocketUser user = guild.RemoveUser(data.User.Id); @@ -1339,12 +1332,12 @@ namespace Discord.WebSocket return; } - user ??= State.GetUser(data.User.Id); + user ??= StateManager.GetUser(data.User.Id); if (user != null) - user.Update(State, data.User); + user.Update(StateManager, data.User); else - user = State.GetOrAddUser(data.User.Id, (x) => SocketGlobalUser.Create(this, State, data.User)); + user = StateManager.GetOrAddUser(data.User.Id, (x) => SocketGlobalUser.Create(this, StateManager, data.User)); await TimedInvokeAsync(_userLeftEvent, nameof(UserLeft), guild, user).ConfigureAwait(false); } @@ -1360,7 +1353,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_MEMBERS_CHUNK)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild != null) { foreach (var memberModel in data.Members) @@ -1385,7 +1378,7 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild == null) { @@ -1410,7 +1403,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (CHANNEL_RECIPIENT_ADD)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - if (State.GetChannel(data.ChannelId) is SocketGroupChannel channel) + if (StateManager.GetChannel(data.ChannelId) is SocketGroupChannel channel) { var user = channel.GetOrAddUser(data.User); await TimedInvokeAsync(_recipientAddedEvent, nameof(RecipientAdded), user).ConfigureAwait(false); @@ -1427,7 +1420,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (CHANNEL_RECIPIENT_REMOVE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - if (State.GetChannel(data.ChannelId) is SocketGroupChannel channel) + if (StateManager.GetChannel(data.ChannelId) is SocketGroupChannel channel) { var user = channel.RemoveUser(data.User.Id); if (user != null) @@ -1454,7 +1447,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_ROLE_CREATE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild != null) { var role = guild.AddRole(data.Role); @@ -1478,14 +1471,14 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_ROLE_UPDATE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild != null) { var role = guild.GetRole(data.Role.Id); if (role != null) { var before = role.Clone(); - role.Update(State, data.Role); + role.Update(StateManager, data.Role); if (!guild.IsSynced) { @@ -1513,7 +1506,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_ROLE_DELETE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild != null) { var role = guild.RemoveRole(data.RoleId); @@ -1548,7 +1541,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_BAN_ADD)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild != null) { if (!guild.IsSynced) @@ -1559,7 +1552,7 @@ namespace Discord.WebSocket SocketUser user = guild.GetUser(data.User.Id); if (user == null) - user = SocketUnknownUser.Create(this, State, data.User); + user = SocketUnknownUser.Create(this, StateManager, data.User); await TimedInvokeAsync(_userBannedEvent, nameof(UserBanned), user, guild).ConfigureAwait(false); } else @@ -1574,7 +1567,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_BAN_REMOVE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild != null) { if (!guild.IsSynced) @@ -1583,9 +1576,9 @@ namespace Discord.WebSocket return; } - SocketUser user = State.GetUser(data.User.Id); + SocketUser user = StateManager.GetUser(data.User.Id); if (user == null) - user = SocketUnknownUser.Create(this, State, data.User); + user = SocketUnknownUser.Create(this, StateManager, data.User); await TimedInvokeAsync(_userUnbannedEvent, nameof(UserUnbanned), user, guild).ConfigureAwait(false); } else @@ -1616,7 +1609,7 @@ namespace Discord.WebSocket { if (!data.GuildId.IsSpecified) // assume it is a DM { - channel = CreateDMChannel(data.ChannelId, data.Author.Value, State); + channel = CreateDMChannel(data.ChannelId, data.Author.Value, StateManager); } else { @@ -1629,7 +1622,7 @@ namespace Discord.WebSocket if (guild != null) { if (data.WebhookId.IsSpecified) - author = SocketWebhookUser.Create(guild, State, data.Author.Value, data.WebhookId.Value); + author = SocketWebhookUser.Create(guild, StateManager, data.Author.Value, data.WebhookId.Value); else author = guild.GetUser(data.Author.Value.Id); } @@ -1657,7 +1650,7 @@ namespace Discord.WebSocket } } - var msg = SocketMessage.Create(this, State, author, channel, data); + var msg = SocketMessage.Create(this, StateManager, author, channel, data); SocketChannelHelper.AddMessage(channel, this, msg); await TimedInvokeAsync(_messageReceivedEvent, nameof(MessageReceived), msg).ConfigureAwait(false); } @@ -1682,7 +1675,7 @@ namespace Discord.WebSocket if (isCached) { before = cachedMsg.Clone(); - cachedMsg.Update(State, data); + cachedMsg.Update(StateManager, data); after = cachedMsg; } else @@ -1694,7 +1687,7 @@ namespace Discord.WebSocket if (guild != null) { if (data.WebhookId.IsSpecified) - author = SocketWebhookUser.Create(guild, State, data.Author.Value, data.WebhookId.Value); + author = SocketWebhookUser.Create(guild, StateManager, data.Author.Value, data.WebhookId.Value); else author = guild.GetUser(data.Author.Value.Id); } @@ -1727,12 +1720,12 @@ namespace Discord.WebSocket { if (data.Author.IsSpecified) { - var dmChannel = CreateDMChannel(data.ChannelId, data.Author.Value, State); + var dmChannel = CreateDMChannel(data.ChannelId, data.Author.Value, StateManager); channel = dmChannel; author = dmChannel.Recipient; } else - channel = CreateDMChannel(data.ChannelId, author, State); + channel = CreateDMChannel(data.ChannelId, author, StateManager); } else { @@ -1741,7 +1734,7 @@ namespace Discord.WebSocket } } - after = SocketMessage.Create(this, State, author, channel, data); + after = SocketMessage.Create(this, StateManager, author, channel, data); } var cacheableBefore = new Cacheable(before, data.Id, isCached, async () => await channel.GetMessageAsync(data.Id).ConfigureAwait(false)); @@ -1941,7 +1934,7 @@ namespace Discord.WebSocket if (data.GuildId.IsSpecified) { - var guild = State.GetGuild(data.GuildId.Value); + var guild = StateManager.GetGuild(data.GuildId.Value); if (guild == null) { await UnknownGuildAsync(type, data.GuildId.Value).ConfigureAwait(false); @@ -1960,12 +1953,12 @@ namespace Discord.WebSocket { return; } - user = guild.AddOrUpdateUser(data); + user = guild.AddOrUpdateUser(data.User); } else { var globalBefore = user.GlobalUser.Clone(); - if (user.GlobalUser.Update(State, data.User)) + if (user.GlobalUser.Update(StateManager, data.User)) { //Global data was updated, trigger UserUpdated await TimedInvokeAsync(_userUpdatedEvent, nameof(UserUpdated), globalBefore, user).ConfigureAwait(false); @@ -1974,7 +1967,7 @@ namespace Discord.WebSocket } else { - user = State.GetUser(data.User.Id); + user = StateManager.GetUser(data.User.Id); if (user == null) { await UnknownGlobalUserAsync(type, data.User.Id).ConfigureAwait(false); @@ -1982,10 +1975,11 @@ namespace Discord.WebSocket } } - var before = user.Presence?.Clone(); - user.Update(State, data.User); - user.Update(data); - await TimedInvokeAsync(_presenceUpdated, nameof(PresenceUpdated), user, before, user.Presence).ConfigureAwait(false); + var before = user.Presence?.Value?.Clone(); + user.Update(StateManager, data.User); + var after = SocketPresence.Create(data); + StateManager.AddOrUpdatePresence(after); + await TimedInvokeAsync(_presenceUpdated, nameof(PresenceUpdated), user, before, after).ConfigureAwait(false); } break; case "TYPING_START": @@ -2028,7 +2022,7 @@ namespace Discord.WebSocket if (!data.GuildId.IsSpecified) return; - var guild = State.GetGuild(data.GuildId.Value); + var guild = StateManager.GetGuild(data.GuildId.Value); if (guild != null) { @@ -2057,7 +2051,7 @@ namespace Discord.WebSocket if (!data.GuildId.IsSpecified) return; - var guild = State.GetGuild(data.GuildId.Value); + var guild = StateManager.GetGuild(data.GuildId.Value); if (guild != null) { @@ -2082,7 +2076,7 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild != null) { @@ -2112,7 +2106,7 @@ namespace Discord.WebSocket if (data.Id == CurrentUser.Id) { var before = CurrentUser.Clone(); - CurrentUser.Update(State, data); + CurrentUser.Update(StateManager, data); await TimedInvokeAsync(_selfUpdatedEvent, nameof(CurrentUserUpdated), before, CurrentUser).ConfigureAwait(false); } else @@ -2134,7 +2128,7 @@ namespace Discord.WebSocket SocketVoiceState before, after; if (data.GuildId != null) { - var guild = State.GetGuild(data.GuildId.Value); + var guild = StateManager.GetGuild(data.GuildId.Value); if (guild == null) { await UnknownGuildAsync(type, data.GuildId.Value).ConfigureAwait(false); @@ -2149,7 +2143,7 @@ namespace Discord.WebSocket if (data.ChannelId != null) { before = guild.GetVoiceState(data.UserId)?.Clone() ?? SocketVoiceState.Default; - after = await guild.AddOrUpdateVoiceStateAsync(State, data).ConfigureAwait(false); + after = await guild.AddOrUpdateVoiceStateAsync(StateManager, data).ConfigureAwait(false); /*if (data.UserId == CurrentUser.Id) { var _ = guild.FinishJoinAudioChannel().ConfigureAwait(false); @@ -2181,7 +2175,7 @@ namespace Discord.WebSocket if (data.ChannelId != null) { before = groupChannel.GetVoiceState(data.UserId)?.Clone() ?? SocketVoiceState.Default; - after = groupChannel.AddOrUpdateVoiceState(State, data); + after = groupChannel.AddOrUpdateVoiceState(StateManager, data); } else { @@ -2198,7 +2192,7 @@ namespace Discord.WebSocket if (user is SocketGuildUser guildUser && data.ChannelId.HasValue) { - SocketStageChannel stage = guildUser.Guild.GetStageChannel(data.ChannelId.Value); + SocketStageChannel stage = guildUser.Guild.Value.GetStageChannel(data.ChannelId.Value); if (stage != null && before.VoiceChannel != null && after.VoiceChannel != null) { @@ -2227,10 +2221,10 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (VOICE_SERVER_UPDATE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); var isCached = guild != null; var cachedGuild = new Cacheable(guild, data.GuildId, isCached, - () => Task.FromResult(State.GetGuild(data.GuildId) as IGuild)); + () => Task.FromResult(StateManager.GetGuild(data.GuildId) as IGuild)); var voiceServer = new SocketVoiceServer(cachedGuild, data.Endpoint, data.Token); await TimedInvokeAsync(_voiceServerUpdatedEvent, nameof(UserVoiceStateUpdated), voiceServer).ConfigureAwait(false); @@ -2261,7 +2255,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (INVITE_CREATE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - if (State.GetChannel(data.ChannelId) is SocketGuildChannel channel) + if (StateManager.GetChannel(data.ChannelId) is SocketGuildChannel channel) { var guild = channel.Guild; if (!guild.IsSynced) @@ -2275,7 +2269,7 @@ namespace Discord.WebSocket : null; SocketUser target = data.TargetUser.IsSpecified - ? (guild.GetUser(data.TargetUser.Value.Id) ?? (SocketUser)SocketUnknownUser.Create(this, State, data.TargetUser.Value)) + ? (guild.GetUser(data.TargetUser.Value.Id) ?? (SocketUser)SocketUnknownUser.Create(this, StateManager, data.TargetUser.Value)) : null; var invite = SocketInvite.Create(this, guild, channel, inviter, target, data); @@ -2294,7 +2288,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (INVITE_DELETE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - if (State.GetChannel(data.ChannelId) is SocketGuildChannel channel) + if (StateManager.GetChannel(data.ChannelId) is SocketGuildChannel channel) { var guild = channel.Guild; if (!guild.IsSynced) @@ -2330,19 +2324,19 @@ namespace Discord.WebSocket } SocketUser user = data.User.IsSpecified - ? State.GetOrAddUser(data.User.Value.Id, (_) => SocketGlobalUser.Create(this, State, data.User.Value)) + ? StateManager.GetOrAddUser(data.User.Value.Id, (_) => SocketGlobalUser.Create(this, StateManager, data.User.Value)) : guild?.AddOrUpdateUser(data.Member.Value); // null if the bot scope isn't set, so the guild cannot be retrieved. SocketChannel channel = null; if(data.ChannelId.IsSpecified) { - channel = State.GetChannel(data.ChannelId.Value); + channel = StateManager.GetChannel(data.ChannelId.Value); if (channel == null) { if (!data.GuildId.IsSpecified) // assume it is a DM { - channel = CreateDMChannel(data.ChannelId.Value, user, State); + channel = CreateDMChannel(data.ChannelId.Value, user, StateManager); } else { @@ -2357,7 +2351,7 @@ namespace Discord.WebSocket } else if (data.User.IsSpecified) { - channel = State.GetDMChannel(data.User.Value.Id); + channel = StateManager.GetDMChannel(data.User.Value.Id); } var interaction = SocketInteraction.Create(this, data, channel as ISocketMessageChannel, user); @@ -2398,7 +2392,7 @@ namespace Discord.WebSocket if (data.GuildId.IsSpecified) { - var guild = State.GetGuild(data.GuildId.Value); + var guild = StateManager.GetGuild(data.GuildId.Value); if (guild == null) { await UnknownGuildAsync(type, data.GuildId.Value).ConfigureAwait(false); @@ -2408,7 +2402,7 @@ namespace Discord.WebSocket var applicationCommand = SocketApplicationCommand.Create(this, data); - State.AddCommand(applicationCommand); + StateManager.AddCommand(applicationCommand); await TimedInvokeAsync(_applicationCommandCreated, nameof(ApplicationCommandCreated), applicationCommand).ConfigureAwait(false); } @@ -2421,7 +2415,7 @@ namespace Discord.WebSocket if (data.GuildId.IsSpecified) { - var guild = State.GetGuild(data.GuildId.Value); + var guild = StateManager.GetGuild(data.GuildId.Value); if (guild == null) { await UnknownGuildAsync(type, data.GuildId.Value).ConfigureAwait(false); @@ -2431,7 +2425,7 @@ namespace Discord.WebSocket var applicationCommand = SocketApplicationCommand.Create(this, data); - State.AddCommand(applicationCommand); + StateManager.AddCommand(applicationCommand); await TimedInvokeAsync(_applicationCommandUpdated, nameof(ApplicationCommandUpdated), applicationCommand).ConfigureAwait(false); } @@ -2444,7 +2438,7 @@ namespace Discord.WebSocket if (data.GuildId.IsSpecified) { - var guild = State.GetGuild(data.GuildId.Value); + var guild = StateManager.GetGuild(data.GuildId.Value); if (guild == null) { await UnknownGuildAsync(type, data.GuildId.Value).ConfigureAwait(false); @@ -2454,7 +2448,7 @@ namespace Discord.WebSocket var applicationCommand = SocketApplicationCommand.Create(this, data); - State.RemoveCommand(applicationCommand.Id); + StateManager.RemoveCommand(applicationCommand.Id); await TimedInvokeAsync(_applicationCommandDeleted, nameof(ApplicationCommandDeleted), applicationCommand).ConfigureAwait(false); } @@ -2468,7 +2462,7 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId.Value); + var guild = StateManager.GetGuild(data.GuildId.Value); if (guild == null) { @@ -2480,14 +2474,14 @@ namespace Discord.WebSocket if ((threadChannel = guild.ThreadChannels.FirstOrDefault(x => x.Id == data.Id)) != null) { - threadChannel.Update(State, data); + threadChannel.Update(StateManager, data); if(data.ThreadMember.IsSpecified) threadChannel.AddOrUpdateThreadMember(data.ThreadMember.Value, guild.CurrentUser); } else { - threadChannel = (SocketThreadChannel)guild.AddChannel(State, data); + threadChannel = (SocketThreadChannel)guild.AddChannel(StateManager, data); if (data.ThreadMember.IsSpecified) threadChannel.AddOrUpdateThreadMember(data.ThreadMember.Value, guild.CurrentUser); } @@ -2501,7 +2495,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Dispatch (THREAD_UPDATE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId.Value); + var guild = StateManager.GetGuild(data.GuildId.Value); if (guild == null) { await UnknownGuildAsync(type, data.GuildId.Value); @@ -2515,7 +2509,7 @@ namespace Discord.WebSocket if (threadChannel != null) { - threadChannel.Update(State, data); + threadChannel.Update(StateManager, data); if (data.ThreadMember.IsSpecified) threadChannel.AddOrUpdateThreadMember(data.ThreadMember.Value, guild.CurrentUser); @@ -2523,7 +2517,7 @@ namespace Discord.WebSocket else { //Thread is updated but was not cached, likely meaning the thread was unarchived. - threadChannel = (SocketThreadChannel)guild.AddChannel(State, data); + threadChannel = (SocketThreadChannel)guild.AddChannel(StateManager, data); if (data.ThreadMember.IsSpecified) threadChannel.AddOrUpdateThreadMember(data.ThreadMember.Value, guild.CurrentUser); } @@ -2543,7 +2537,7 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId.Value); + var guild = StateManager.GetGuild(data.GuildId.Value); if(guild == null) { @@ -2564,7 +2558,7 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if(guild == null) { @@ -2578,11 +2572,11 @@ namespace Discord.WebSocket if(entity == null) { - entity = (SocketThreadChannel)guild.AddChannel(State, thread); + entity = (SocketThreadChannel)guild.AddChannel(StateManager, thread); } else { - entity.Update(State, thread); + entity.Update(StateManager, thread); } foreach(var member in data.Members.Where(x => x.Id.Value == entity.Id)) @@ -2600,7 +2594,7 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); - var thread = (SocketThreadChannel)State.GetChannel(data.Id.Value); + var thread = (SocketThreadChannel)StateManager.GetChannel(data.Id.Value); if (thread == null) { @@ -2618,7 +2612,7 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild == null) { @@ -2691,7 +2685,7 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if(guild == null) { @@ -2734,7 +2728,7 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild == null) { @@ -2753,7 +2747,7 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild == null) { @@ -2784,7 +2778,7 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if (guild == null) { @@ -2803,7 +2797,7 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); - var guild = State.GetGuild(data.GuildId); + var guild = StateManager.GetGuild(data.GuildId); if(guild == null) { @@ -2819,7 +2813,7 @@ namespace Discord.WebSocket return; } - var user = (SocketUser)guild.GetUser(data.UserId) ?? State.GetUser(data.UserId); + var user = (SocketUser)guild.GetUser(data.UserId) ?? StateManager.GetUser(data.UserId); var cacheableUser = new Cacheable(user, data.UserId, user != null, () => Rest.GetUserAsync(data.UserId)); @@ -2954,7 +2948,7 @@ namespace Discord.WebSocket await ApiClient.SendGuildSyncAsync(guildIds).ConfigureAwait(false); } - internal SocketGuild AddGuild(ExtendedGuild model, ClientState state) + internal SocketGuild AddGuild(ExtendedGuild model, ClientStateManager state) { var guild = SocketGuild.Create(this, state, model); state.AddGuild(guild); @@ -2963,26 +2957,26 @@ namespace Discord.WebSocket return guild; } internal SocketGuild RemoveGuild(ulong id) - => State.RemoveGuild(id); + => StateManager.RemoveGuild(id); /// Unexpected channel type is created. - internal ISocketPrivateChannel AddPrivateChannel(API.Channel model, ClientState state) + internal ISocketPrivateChannel AddPrivateChannel(API.Channel model, ClientStateManager state) { var channel = SocketChannel.CreatePrivate(this, state, model); state.AddChannel(channel as SocketChannel); return channel; } - internal SocketDMChannel CreateDMChannel(ulong channelId, API.User model, ClientState state) + internal SocketDMChannel CreateDMChannel(ulong channelId, API.User model, ClientStateManager state) { return SocketDMChannel.Create(this, state, channelId, model); } - internal SocketDMChannel CreateDMChannel(ulong channelId, SocketUser user, ClientState state) + internal SocketDMChannel CreateDMChannel(ulong channelId, SocketUser user, ClientStateManager state) { return new SocketDMChannel(this, channelId, user); } internal ISocketPrivateChannel RemovePrivateChannel(ulong id) { - var channel = State.RemoveChannel(id) as ISocketPrivateChannel; + var channel = StateManager.RemoveChannel(id) as ISocketPrivateChannel; if (channel != null) { foreach (var recipient in channel.Recipients) @@ -2992,8 +2986,8 @@ namespace Discord.WebSocket } internal void RemoveDMChannels() { - var channels = State.DMChannels; - State.PurgeDMChannels(); + var channels = StateManager.DMChannels; + StateManager.PurgeDMChannels(); foreach (var channel in channels) channel.Recipient.GlobalUser.RemoveRef(this); } diff --git a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs index 4cd64dbc2..21d84ba3b 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs @@ -25,6 +25,12 @@ namespace Discord.WebSocket /// public class DiscordSocketConfig : DiscordRestConfig { + /// + /// Gets or sets the cache provider to use + /// + public ICacheProvider CacheProvider { get; set; } + public IStateProvider StateProvider { get; set; } + /// /// Returns the encoding gateway should use. /// @@ -193,6 +199,11 @@ namespace Discord.WebSocket /// public bool SuppressUnknownDispatchWarnings { get; set; } = true; + /// + /// Gets or sets the default state behavior clients will use. + /// + public StateBehavior DefaultStateBehavior { get; set; } = StateBehavior.Default; + /// /// Initializes a new instance of the class with the default configuration. /// diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketCategoryChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketCategoryChannel.cs index 43f23de1a..fc5b2bb2d 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketCategoryChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketCategoryChannel.cs @@ -37,7 +37,7 @@ namespace Discord.WebSocket : base(discord, id, guild) { } - internal new static SocketCategoryChannel Create(SocketGuild guild, ClientState state, Model model) + internal new static SocketCategoryChannel Create(SocketGuild guild, ClientStateManager state, Model model) { var entity = new SocketCategoryChannel(guild.Discord, model.Id, guild); entity.Update(state, model); diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketChannel.cs index c30b3d254..0ee1e9a98 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketChannel.cs @@ -29,7 +29,7 @@ namespace Discord.WebSocket } /// Unexpected channel type is created. - internal static ISocketPrivateChannel CreatePrivate(DiscordSocketClient discord, ClientState state, Model model) + internal static ISocketPrivateChannel CreatePrivate(DiscordSocketClient discord, ClientStateManager state, Model model) { return model.Type switch { @@ -38,7 +38,7 @@ namespace Discord.WebSocket _ => throw new InvalidOperationException($"Unexpected channel type: {model.Type}"), }; } - internal abstract void Update(ClientState state, Model model); + internal abstract void Update(ClientStateManager state, Model model); #endregion #region User diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketDMChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketDMChannel.cs index 17ab4ebe3..755fa7ab3 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketDMChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketDMChannel.cs @@ -35,23 +35,23 @@ namespace Discord.WebSocket { Recipient = recipient; } - internal static SocketDMChannel Create(DiscordSocketClient discord, ClientState state, Model model) + internal static SocketDMChannel Create(DiscordSocketClient discord, ClientStateManager state, Model model) { var entity = new SocketDMChannel(discord, model.Id, discord.GetOrCreateTemporaryUser(state, model.Recipients.Value[0])); entity.Update(state, model); return entity; } - internal override void Update(ClientState state, Model model) + internal override void Update(ClientStateManager state, Model model) { Recipient.Update(state, model.Recipients.Value[0]); } - internal static SocketDMChannel Create(DiscordSocketClient discord, ClientState state, ulong channelId, API.User recipient) + internal static SocketDMChannel Create(DiscordSocketClient discord, ClientStateManager state, ulong channelId, API.User recipient) { var entity = new SocketDMChannel(discord, channelId, discord.GetOrCreateTemporaryUser(state, recipient)); entity.Update(state, recipient); return entity; } - internal void Update(ClientState state, API.User recipient) + internal void Update(ClientStateManager state, API.User recipient) { Recipient.Update(state, recipient); } diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketGroupChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketGroupChannel.cs index 4f068cf81..f6736245d 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketGroupChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketGroupChannel.cs @@ -55,13 +55,13 @@ namespace Discord.WebSocket _voiceStates = new ConcurrentDictionary(ConcurrentHashSet.DefaultConcurrencyLevel, 5); _users = new ConcurrentDictionary(ConcurrentHashSet.DefaultConcurrencyLevel, 5); } - internal static SocketGroupChannel Create(DiscordSocketClient discord, ClientState state, Model model) + internal static SocketGroupChannel Create(DiscordSocketClient discord, ClientStateManager state, Model model) { var entity = new SocketGroupChannel(discord, model.Id); entity.Update(state, model); return entity; } - internal override void Update(ClientState state, Model model) + internal override void Update(ClientStateManager state, Model model) { if (model.Name.IsSpecified) Name = model.Name.Value; @@ -73,7 +73,7 @@ namespace Discord.WebSocket RTCRegion = model.RTCRegion.GetValueOrDefault(null); } - private void UpdateUsers(ClientState state, UserModel[] models) + private void UpdateUsers(ClientStateManager state, UserModel[] models) { var users = new ConcurrentDictionary(ConcurrentHashSet.DefaultConcurrencyLevel, (int)(models.Length * 1.05)); for (int i = 0; i < models.Length; i++) @@ -265,7 +265,7 @@ namespace Discord.WebSocket return user; else { - var privateUser = SocketGroupUser.Create(this, Discord.State, model); + var privateUser = SocketGroupUser.Create(this, Discord.StateManager, model); privateUser.GlobalUser.AddRef(); _users[privateUser.Id] = privateUser; return privateUser; @@ -283,7 +283,7 @@ namespace Discord.WebSocket #endregion #region Voice States - internal SocketVoiceState AddOrUpdateVoiceState(ClientState state, VoiceStateModel model) + internal SocketVoiceState AddOrUpdateVoiceState(ClientStateManager state, VoiceStateModel model) { var voiceChannel = state.GetChannel(model.ChannelId.Value) as SocketVoiceChannel; var voiceState = SocketVoiceState.Create(voiceChannel, model); diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketGuildChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketGuildChannel.cs index 79f02fe1c..2d6e4c273 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketGuildChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketGuildChannel.cs @@ -49,7 +49,7 @@ namespace Discord.WebSocket { Guild = guild; } - internal static SocketGuildChannel Create(SocketGuild guild, ClientState state, Model model) + internal static SocketGuildChannel Create(SocketGuild guild, ClientStateManager state, Model model) { return model.Type switch { @@ -63,7 +63,7 @@ namespace Discord.WebSocket }; } /// - internal override void Update(ClientState state, Model model) + internal override void Update(ClientStateManager state, Model model) { Name = model.Name.Value; Position = model.Position.GetValueOrDefault(0); diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketNewsChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketNewsChannel.cs index eed8f9374..56d035da6 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketNewsChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketNewsChannel.cs @@ -21,7 +21,7 @@ namespace Discord.WebSocket :base(discord, id, guild) { } - internal new static SocketNewsChannel Create(SocketGuild guild, ClientState state, Model model) + internal new static SocketNewsChannel Create(SocketGuild guild, ClientStateManager state, Model model) { var entity = new SocketNewsChannel(guild.Discord, model.Id, guild); entity.Update(state, model); diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketStageChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketStageChannel.cs index 91bca5054..d98a31ff2 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketStageChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketStageChannel.cs @@ -43,7 +43,7 @@ namespace Discord.WebSocket internal SocketStageChannel(DiscordSocketClient discord, ulong id, SocketGuild guild) : base(discord, id, guild) { } - internal new static SocketStageChannel Create(SocketGuild guild, ClientState state, Model model) + internal new static SocketStageChannel Create(SocketGuild guild, ClientStateManager state, Model model) { var entity = new SocketStageChannel(guild.Discord, model.Id, guild); entity.Update(state, model); diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketTextChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketTextChannel.cs index e4a299edc..5aecf11c1 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketTextChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketTextChannel.cs @@ -63,13 +63,13 @@ namespace Discord.WebSocket if (Discord.MessageCacheSize > 0) _messages = new MessageCache(Discord); } - internal new static SocketTextChannel Create(SocketGuild guild, ClientState state, Model model) + internal new static SocketTextChannel Create(SocketGuild guild, ClientStateManager state, Model model) { var entity = new SocketTextChannel(guild.Discord, model.Id, guild); entity.Update(state, model); return entity; } - internal override void Update(ClientState state, Model model) + internal override void Update(ClientStateManager state, Model model) { base.Update(state, model); CategoryId = model.CategoryId; @@ -117,7 +117,7 @@ namespace Discord.WebSocket { var model = await ThreadHelper.CreateThreadAsync(Discord, this, name, type, autoArchiveDuration, message, invitable, slowmode, options); - var thread = (SocketThreadChannel)Guild.AddOrUpdateChannel(Discord.State, model); + var thread = (SocketThreadChannel)Guild.AddOrUpdateChannel(Discord.StateManager, model); if(Discord.AlwaysDownloadUsers && Discord.HasGatewayIntent(GatewayIntents.GuildMembers)) await thread.DownloadUsersAsync(); diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketThreadChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketThreadChannel.cs index 78462b062..4ff39e5e5 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketThreadChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketThreadChannel.cs @@ -118,7 +118,7 @@ namespace Discord.WebSocket CreatedAt = createdAt ?? new DateTimeOffset(2022, 1, 9, 0, 0, 0, TimeSpan.Zero); } - internal new static SocketThreadChannel Create(SocketGuild guild, ClientState state, Model model) + internal new static SocketThreadChannel Create(SocketGuild guild, ClientStateManager state, Model model) { var parent = guild.GetChannel(model.CategoryId.Value); var entity = new SocketThreadChannel(guild.Discord, guild, model.Id, parent, model.ThreadMetadata.GetValueOrDefault()?.CreatedAt.GetValueOrDefault(null)); @@ -126,7 +126,7 @@ namespace Discord.WebSocket return entity; } - internal override void Update(ClientState state, Model model) + internal override void Update(ClientStateManager state, Model model) { base.Update(state, model); diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketVoiceChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketVoiceChannel.cs index 00003d4ed..d684ffa9f 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketVoiceChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketVoiceChannel.cs @@ -55,14 +55,14 @@ namespace Discord.WebSocket : base(discord, id, guild) { } - internal new static SocketVoiceChannel Create(SocketGuild guild, ClientState state, Model model) + internal new static SocketVoiceChannel Create(SocketGuild guild, ClientStateManager state, Model model) { var entity = new SocketVoiceChannel(guild.Discord, model.Id, guild); entity.Update(state, model); return entity; } /// - internal override void Update(ClientState state, Model model) + internal override void Update(ClientStateManager state, Model model) { base.Update(state, model); CategoryId = model.CategoryId; diff --git a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs index 49d2cd3bd..1a16e8a25 100644 --- a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs +++ b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs @@ -14,11 +14,11 @@ using ChannelModel = Discord.API.Channel; using EmojiUpdateModel = Discord.API.Gateway.GuildEmojiUpdateEvent; using ExtendedModel = Discord.API.Gateway.ExtendedGuild; using GuildSyncModel = Discord.API.Gateway.GuildSyncEvent; -using MemberModel = Discord.API.GuildMember; +using MemberModel = Discord.IMemberModel; using Model = Discord.API.Guild; using PresenceModel = Discord.API.Presence; using RoleModel = Discord.API.Role; -using UserModel = Discord.API.User; +using UserModel = Discord.IUserModel; using VoiceStateModel = Discord.API.VoiceState; using StickerModel = Discord.API.Sticker; using EventModel = Discord.API.GuildScheduledEvent; @@ -38,7 +38,7 @@ namespace Discord.WebSocket private TaskCompletionSource _syncPromise, _downloaderPromise; private TaskCompletionSource _audioConnectPromise; private ConcurrentDictionary _channels; - private ConcurrentDictionary _members; + //private ConcurrentDictionary _members; private ConcurrentDictionary _roles; private ConcurrentDictionary _voiceStates; private ConcurrentDictionary _stickers; @@ -305,7 +305,7 @@ namespace Discord.WebSocket /// /// Gets the current logged-in user. /// - public SocketGuildUser CurrentUser => _members.TryGetValue(Discord.CurrentUser.Id, out SocketGuildUser member) ? member : null; + public SocketGuildUser CurrentUser => Discord.StateManager.GetMember(Discord.CurrentUser.Id, Id); /// /// Gets the built-in role containing all users in this guild. /// @@ -324,7 +324,7 @@ namespace Discord.WebSocket get { var channels = _channels; - var state = Discord.State; + var state = Discord.StateManager; return channels.Select(x => x.Value).Where(x => x != null).ToReadOnlyCollection(channels); } } @@ -356,7 +356,7 @@ namespace Discord.WebSocket /// /// A collection of guild users found within this guild. /// - public IReadOnlyCollection Users => _members.ToReadOnlyCollection(); + public IReadOnlyCollection Users => Discord.StateManager.GetMembers(Id).Cast().ToImmutableArray(); /// /// Gets a collection of all roles in this guild. /// @@ -382,13 +382,13 @@ namespace Discord.WebSocket _audioLock = new SemaphoreSlim(1, 1); _emotes = ImmutableArray.Create(); } - internal static SocketGuild Create(DiscordSocketClient discord, ClientState state, ExtendedModel model) + internal static SocketGuild Create(DiscordSocketClient discord, ClientStateManager state, ExtendedModel model) { var entity = new SocketGuild(discord, model.Id); entity.Update(state, model); return entity; } - internal void Update(ClientState state, ExtendedModel model) + internal void Update(ClientStateManager state, ExtendedModel model) { IsAvailable = !(model.Unavailable ?? false); if (!IsAvailable) @@ -397,8 +397,6 @@ namespace Discord.WebSocket _events = new ConcurrentDictionary(); if (_channels == null) _channels = new ConcurrentDictionary(); - if (_members == null) - _members = new ConcurrentDictionary(); if (_roles == null) _roles = new ConcurrentDictionary(); /*if (Emojis == null) @@ -431,25 +429,6 @@ namespace Discord.WebSocket _channels = channels; - var members = new ConcurrentDictionary(ConcurrentHashSet.DefaultConcurrencyLevel, (int)(model.Members.Length * 1.05)); - { - for (int i = 0; i < model.Members.Length; i++) - { - var member = SocketGuildUser.Create(this, state, model.Members[i]); - if (members.TryAdd(member.Id, member)) - member.GlobalUser.AddRef(); - } - DownloadedMemberCount = members.Count; - - for (int i = 0; i < model.Presences.Length; i++) - { - if (members.TryGetValue(model.Presences[i].User.Id, out SocketGuildUser member)) - member.Update(state, model.Presences[i], true); - } - } - _members = members; - MemberCount = model.MemberCount; - var voiceStates = new ConcurrentDictionary(ConcurrentHashSet.DefaultConcurrencyLevel, (int)(model.VoiceStates.Length * 1.05)); { for (int i = 0; i < model.VoiceStates.Length; i++) @@ -473,6 +452,19 @@ namespace Discord.WebSocket } _events = events; + for (int i = 0; i < model.Members.Length; i++) + { + Discord.StateManager.AddOrUpdateMember(Id, SocketGuildUser.Create(Id, Discord, model.Members[i])); + } + DownloadedMemberCount = model.Members.Length; + + for (int i = 0; i < model.Presences.Length; i++) + { + Discord.StateManager.AddOrUpdatePresence(SocketPresence.Create(model.Presences[i])); + } + + MemberCount = model.MemberCount; + _syncPromise = new TaskCompletionSource(); _downloaderPromise = new TaskCompletionSource(); @@ -480,7 +472,7 @@ namespace Discord.WebSocket /*if (!model.Large) _ = _downloaderPromise.TrySetResultAsync(true);*/ } - internal void Update(ClientState state, Model model) + internal void Update(ClientStateManager state, Model model) { AFKChannelId = model.AFKChannelId; if (model.WidgetChannelId.IsSpecified) @@ -561,7 +553,7 @@ namespace Discord.WebSocket else _stickers = new ConcurrentDictionary(ConcurrentHashSet.DefaultConcurrencyLevel, 7); } - /*internal void Update(ClientState state, GuildSyncModel model) //TODO remove? userbot related + /*internal void Update(ClientStateManager state, GuildSyncModel model) //TODO remove? userbot related { var members = new ConcurrentDictionary(ConcurrentHashSet.DefaultConcurrencyLevel, (int)(model.Members.Length * 1.05)); { @@ -585,7 +577,7 @@ namespace Discord.WebSocket // _ = _downloaderPromise.TrySetResultAsync(true); }*/ - internal void Update(ClientState state, EmojiUpdateModel model) + internal void Update(ClientStateManager state, EmojiUpdateModel model) { var emotes = ImmutableArray.CreateBuilder(model.Emojis.Length); for (int i = 0; i < model.Emojis.Length; i++) @@ -682,7 +674,7 @@ namespace Discord.WebSocket /// public SocketGuildChannel GetChannel(ulong id) { - var channel = Discord.State.GetChannel(id) as SocketGuildChannel; + var channel = Discord.StateManager.GetChannel(id) as SocketGuildChannel; if (channel?.Guild.Id == Id) return channel; return null; @@ -799,7 +791,7 @@ namespace Discord.WebSocket public Task CreateCategoryChannelAsync(string name, Action func = null, RequestOptions options = null) => GuildHelper.CreateCategoryChannelAsync(this, Discord, name, options, func); - internal SocketGuildChannel AddChannel(ClientState state, ChannelModel model) + internal SocketGuildChannel AddChannel(ClientStateManager state, ChannelModel model) { var channel = SocketGuildChannel.Create(this, state, model); _channels.TryAdd(model.Id, channel); @@ -807,26 +799,26 @@ namespace Discord.WebSocket return channel; } - internal SocketGuildChannel AddOrUpdateChannel(ClientState state, ChannelModel model) + internal SocketGuildChannel AddOrUpdateChannel(ClientStateManager state, ChannelModel model) { if (_channels.TryGetValue(model.Id, out SocketGuildChannel channel)) - channel.Update(Discord.State, model); + channel.Update(Discord.StateManager, model); else { - channel = SocketGuildChannel.Create(this, Discord.State, model); + channel = SocketGuildChannel.Create(this, Discord.StateManager, model); _channels[channel.Id] = channel; state.AddChannel(channel); } return channel; } - internal SocketGuildChannel RemoveChannel(ClientState state, ulong id) + internal SocketGuildChannel RemoveChannel(ClientStateManager state, ulong id) { if (_channels.TryRemove(id, out var _)) return state.RemoveChannel(id) as SocketGuildChannel; return null; } - internal void PurgeChannelCache(ClientState state) + internal void PurgeChannelCache(ClientStateManager state) { foreach (var channelId in _channels) state.RemoveChannel(channelId.Key); @@ -880,7 +872,7 @@ namespace Discord.WebSocket foreach (var command in commands) { - Discord.State.AddCommand(command); + Discord.StateManager.AddCommand(command); } return commands.ToImmutableArray(); @@ -898,7 +890,7 @@ namespace Discord.WebSocket /// public async ValueTask GetApplicationCommandAsync(ulong id, CacheMode mode = CacheMode.AllowDownload, RequestOptions options = null) { - var command = Discord.State.GetCommand(id); + var command = Discord.StateManager.GetCommand(id); if (command != null) return command; @@ -913,7 +905,7 @@ namespace Discord.WebSocket command = SocketApplicationCommand.Create(Discord, model, Id); - Discord.State.AddCommand(command); + Discord.StateManager.AddCommand(command); return command; } @@ -930,7 +922,7 @@ namespace Discord.WebSocket { var model = await InteractionHelper.CreateGuildCommandAsync(Discord, Id, properties, options); - var entity = Discord.State.GetOrAddCommand(model.Id, (id) => SocketApplicationCommand.Create(Discord, model)); + var entity = Discord.StateManager.GetOrAddCommand(model.Id, (id) => SocketApplicationCommand.Create(Discord, model)); entity.Update(model); @@ -952,11 +944,11 @@ namespace Discord.WebSocket var entities = models.Select(x => SocketApplicationCommand.Create(Discord, x)); - Discord.State.PurgeCommands(x => !x.IsGlobalCommand && x.Guild.Id == Id); + Discord.StateManager.PurgeCommands(x => !x.IsGlobalCommand && x.Guild.Id == Id); foreach(var entity in entities) { - Discord.State.AddCommand(entity); + Discord.StateManager.AddCommand(entity); } return entities.ToImmutableArray(); @@ -1020,7 +1012,7 @@ namespace Discord.WebSocket => GuildHelper.CreateRoleAsync(this, Discord, name, permissions, color, isHoisted, isMentionable, options); internal SocketRole AddRole(RoleModel model) { - var role = SocketRole.Create(this, Discord.State, model); + var role = SocketRole.Create(this, Discord.StateManager, model); _roles[model.Id] = role; return role; } @@ -1034,7 +1026,7 @@ namespace Discord.WebSocket internal SocketRole AddOrUpdateRole(RoleModel model) { if (_roles.TryGetValue(model.Id, out SocketRole role)) - _roles[model.Id].Update(Discord.State, model); + _roles[model.Id].Update(Discord.StateManager, model); else role = AddRole(model); @@ -1089,60 +1081,45 @@ namespace Discord.WebSocket /// A guild user associated with the specified ; if none is found. /// public SocketGuildUser GetUser(ulong id) - { - if (_members.TryGetValue(id, out SocketGuildUser member)) - return member; - return null; - } + => Discord.StateManager.GetMember(id, Id); /// public Task PruneUsersAsync(int days = 30, bool simulate = false, RequestOptions options = null, IEnumerable includeRoleIds = null) => GuildHelper.PruneUsersAsync(this, Discord, days, simulate, options, includeRoleIds); internal SocketGuildUser AddOrUpdateUser(UserModel model) { - if (_members.TryGetValue(model.Id, out SocketGuildUser member)) - member.GlobalUser?.Update(Discord.State, model); + SocketGuildUser member; + if ((member = GetUser(model.Id)) != null) + member.GlobalUser?.Update(Discord.StateManager, model); else { - member = SocketGuildUser.Create(this, Discord.State, model); + member = SocketGuildUser.Create(Id, Discord, model); member.GlobalUser.AddRef(); - _members[member.Id] = member; DownloadedMemberCount++; } return member; } internal SocketGuildUser AddOrUpdateUser(MemberModel model) { - if (_members.TryGetValue(model.User.Id, out SocketGuildUser member)) - member.Update(Discord.State, model); + SocketGuildUser member; + if ((member = GetUser(model.User.Id)) != null) + member.Update(Discord.StateManager, model); else { - member = SocketGuildUser.Create(this, Discord.State, model); + member = SocketGuildUser.Create(Id, Discord, model); member.GlobalUser.AddRef(); - _members[member.Id] = member; - DownloadedMemberCount++; - } - return member; - } - internal SocketGuildUser AddOrUpdateUser(PresenceModel model) - { - if (_members.TryGetValue(model.User.Id, out SocketGuildUser member)) - member.Update(Discord.State, model, false); - else - { - member = SocketGuildUser.Create(this, Discord.State, model); - member.GlobalUser.AddRef(); - _members[member.Id] = member; DownloadedMemberCount++; } return member; } internal SocketGuildUser RemoveUser(ulong id) { - if (_members.TryRemove(id, out SocketGuildUser member)) + SocketGuildUser member; + if ((member = GetUser(id)) != null) { DownloadedMemberCount--; member.GlobalUser.RemoveRef(Discord); + Discord.StateManager.RemoveMember(id, Id); return member; } return null; @@ -1158,18 +1135,16 @@ namespace Discord.WebSocket /// The predicate used to select which users to clear. public void PurgeUserCache(Func predicate) { - var membersToPurge = Users.Where(x => predicate.Invoke(x) && x?.Id != Discord.CurrentUser.Id); - var membersToKeep = Users.Where(x => !predicate.Invoke(x) || x?.Id == Discord.CurrentUser.Id); + var users = Users.ToArray(); - foreach (var member in membersToPurge) - if(_members.TryRemove(member.Id, out _)) - member.GlobalUser.RemoveRef(Discord); + var membersToPurge = users.Where(x => predicate.Invoke(x) && x?.Id != Discord.CurrentUser.Id); + var membersToKeep = users.Where(x => !predicate.Invoke(x) || x?.Id == Discord.CurrentUser.Id); - foreach (var member in membersToKeep) - _members.TryAdd(member.Id, member); + foreach (var member in membersToPurge) + Discord.StateManager.RemoveMember(member.Id, Id); _downloaderPromise = new TaskCompletionSource(); - DownloadedMemberCount = _members.Count; + DownloadedMemberCount = membersToKeep.Count(); } /// @@ -1537,7 +1512,7 @@ namespace Discord.WebSocket #endregion #region Voice States - internal async Task AddOrUpdateVoiceStateAsync(ClientState state, VoiceStateModel model) + internal async Task AddOrUpdateVoiceStateAsync(ClientStateManager state, VoiceStateModel model) { var voiceChannel = state.GetChannel(model.ChannelId.Value) as SocketVoiceChannel; var before = GetVoiceState(model.UserId) ?? SocketVoiceState.Default; diff --git a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuildEvent.cs b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuildEvent.cs index a86aafadf..9f019cdb1 100644 --- a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuildEvent.cs +++ b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuildEvent.cs @@ -89,13 +89,13 @@ namespace Discord.WebSocket if(guildUser != null) { if(model.Creator.IsSpecified) - guildUser.Update(Discord.State, model.Creator.Value); + guildUser.Update(Discord.StateManager, model.Creator.Value); Creator = guildUser; } else if (guildUser == null && model.Creator.IsSpecified) { - guildUser = SocketGuildUser.Create(Guild, Discord.State, model.Creator.Value); + guildUser = SocketGuildUser.Create(Guild.Id, Discord, model.Creator.Value); Creator = guildUser; } } diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs b/src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs index aeff465bd..28a922e65 100644 --- a/src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs +++ b/src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs @@ -56,18 +56,18 @@ namespace Discord.WebSocket if (Channel is SocketGuildChannel channel) { if (model.Message.Value.WebhookId.IsSpecified) - author = SocketWebhookUser.Create(channel.Guild, Discord.State, model.Message.Value.Author.Value, model.Message.Value.WebhookId.Value); + author = SocketWebhookUser.Create(channel.Guild, Discord.StateManager, model.Message.Value.Author.Value, model.Message.Value.WebhookId.Value); else if (model.Message.Value.Author.IsSpecified) author = channel.Guild.GetUser(model.Message.Value.Author.Value.Id); } else if (model.Message.Value.Author.IsSpecified) author = (Channel as SocketChannel).GetUser(model.Message.Value.Author.Value.Id); - Message = SocketUserMessage.Create(Discord, Discord.State, author, Channel, model.Message.Value); + Message = SocketUserMessage.Create(Discord, Discord.StateManager, author, Channel, model.Message.Value); } else { - Message.Update(Discord.State, model.Message.Value); + Message.Update(Discord.StateManager, model.Message.Value); } } } diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketResolvableData.cs b/src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketResolvableData.cs index d722c5a13..d36960749 100644 --- a/src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketResolvableData.cs +++ b/src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketResolvableData.cs @@ -29,7 +29,7 @@ namespace Discord.WebSocket { foreach (var user in resolved.Users.Value) { - var socketUser = discord.GetOrCreateUser(discord.State, user.Value); + var socketUser = discord.GetOrCreateUser(discord.StateManager, user.Value); Users.Add(ulong.Parse(user.Key), socketUser); } @@ -50,11 +50,11 @@ namespace Discord.WebSocket : discord.Rest.ApiClient.GetChannelAsync(channel.Value.Id).ConfigureAwait(false).GetAwaiter().GetResult(); socketChannel = guild != null - ? SocketGuildChannel.Create(guild, discord.State, channelModel) - : (SocketChannel)SocketChannel.CreatePrivate(discord, discord.State, channelModel); + ? SocketGuildChannel.Create(guild, discord.StateManager, channelModel) + : (SocketChannel)SocketChannel.CreatePrivate(discord, discord.StateManager, channelModel); } - discord.State.AddChannel(socketChannel); + discord.StateManager.AddChannel(socketChannel); Channels.Add(ulong.Parse(channel.Key), socketChannel); } } @@ -88,7 +88,7 @@ namespace Discord.WebSocket if (guild != null) { if (msg.Value.WebhookId.IsSpecified) - author = SocketWebhookUser.Create(guild, discord.State, msg.Value.Author.Value, msg.Value.WebhookId.Value); + author = SocketWebhookUser.Create(guild, discord.StateManager, msg.Value.Author.Value, msg.Value.WebhookId.Value); else author = guild.GetUser(msg.Value.Author.Value.Id); } @@ -99,11 +99,11 @@ namespace Discord.WebSocket { if (!msg.Value.GuildId.IsSpecified) // assume it is a DM { - channel = discord.CreateDMChannel(msg.Value.ChannelId, msg.Value.Author.Value, discord.State); + channel = discord.CreateDMChannel(msg.Value.ChannelId, msg.Value.Author.Value, discord.StateManager); } } - var message = SocketMessage.Create(discord, discord.State, author, channel, msg.Value); + var message = SocketMessage.Create(discord, discord.StateManager, author, channel, msg.Value); Messages.Add(message.Id, message); } } diff --git a/src/Discord.Net.WebSocket/Entities/Messages/SocketMessage.cs b/src/Discord.Net.WebSocket/Entities/Messages/SocketMessage.cs index 6668426e1..51a691b6f 100644 --- a/src/Discord.Net.WebSocket/Entities/Messages/SocketMessage.cs +++ b/src/Discord.Net.WebSocket/Entities/Messages/SocketMessage.cs @@ -129,7 +129,7 @@ namespace Discord.WebSocket Author = author; Source = source; } - internal static SocketMessage Create(DiscordSocketClient discord, ClientState state, SocketUser author, ISocketMessageChannel channel, Model model) + internal static SocketMessage Create(DiscordSocketClient discord, ClientStateManager state, SocketUser author, ISocketMessageChannel channel, Model model) { if (model.Type == MessageType.Default || model.Type == MessageType.Reply || @@ -140,7 +140,7 @@ namespace Discord.WebSocket else return SocketSystemMessage.Create(discord, state, author, channel, model); } - internal virtual void Update(ClientState state, Model model) + internal virtual void Update(ClientStateManager state, Model model) { Type = model.Type; diff --git a/src/Discord.Net.WebSocket/Entities/Messages/SocketSystemMessage.cs b/src/Discord.Net.WebSocket/Entities/Messages/SocketSystemMessage.cs index ec22a7703..50fbec4b7 100644 --- a/src/Discord.Net.WebSocket/Entities/Messages/SocketSystemMessage.cs +++ b/src/Discord.Net.WebSocket/Entities/Messages/SocketSystemMessage.cs @@ -13,13 +13,13 @@ namespace Discord.WebSocket : base(discord, id, channel, author, MessageSource.System) { } - internal new static SocketSystemMessage Create(DiscordSocketClient discord, ClientState state, SocketUser author, ISocketMessageChannel channel, Model model) + internal new static SocketSystemMessage Create(DiscordSocketClient discord, ClientStateManager state, SocketUser author, ISocketMessageChannel channel, Model model) { var entity = new SocketSystemMessage(discord, model.Id, channel, author); entity.Update(state, model); return entity; } - internal override void Update(ClientState state, Model model) + internal override void Update(ClientStateManager state, Model model) { base.Update(state, model); } diff --git a/src/Discord.Net.WebSocket/Entities/Messages/SocketUserMessage.cs b/src/Discord.Net.WebSocket/Entities/Messages/SocketUserMessage.cs index e5776a089..94c081d75 100644 --- a/src/Discord.Net.WebSocket/Entities/Messages/SocketUserMessage.cs +++ b/src/Discord.Net.WebSocket/Entities/Messages/SocketUserMessage.cs @@ -53,14 +53,14 @@ namespace Discord.WebSocket : base(discord, id, channel, author, source) { } - internal new static SocketUserMessage Create(DiscordSocketClient discord, ClientState state, SocketUser author, ISocketMessageChannel channel, Model model) + internal new static SocketUserMessage Create(DiscordSocketClient discord, ClientStateManager state, SocketUser author, ISocketMessageChannel channel, Model model) { var entity = new SocketUserMessage(discord, model.Id, channel, author, MessageHelper.GetSource(model)); entity.Update(state, model); return entity; } - internal override void Update(ClientState state, Model model) + internal override void Update(ClientStateManager state, Model model) { base.Update(state, model); diff --git a/src/Discord.Net.WebSocket/Entities/Roles/SocketRole.cs b/src/Discord.Net.WebSocket/Entities/Roles/SocketRole.cs index 1e90b8f5c..fdadaa44a 100644 --- a/src/Discord.Net.WebSocket/Entities/Roles/SocketRole.cs +++ b/src/Discord.Net.WebSocket/Entities/Roles/SocketRole.cs @@ -67,13 +67,13 @@ namespace Discord.WebSocket { Guild = guild; } - internal static SocketRole Create(SocketGuild guild, ClientState state, Model model) + internal static SocketRole Create(SocketGuild guild, ClientStateManager state, Model model) { var entity = new SocketRole(guild, model.Id); entity.Update(state, model); return entity; } - internal void Update(ClientState state, Model model) + internal void Update(ClientStateManager state, Model model) { Name = model.Name; IsHoisted = model.Hoist; diff --git a/src/Discord.Net.WebSocket/Entities/Users/SocketGlobalUser.cs b/src/Discord.Net.WebSocket/Entities/Users/SocketGlobalUser.cs index 236e7d432..41eadcc4c 100644 --- a/src/Discord.Net.WebSocket/Entities/Users/SocketGlobalUser.cs +++ b/src/Discord.Net.WebSocket/Entities/Users/SocketGlobalUser.cs @@ -1,18 +1,17 @@ using System; using System.Diagnostics; using System.Linq; -using Model = Discord.API.User; +using Model = Discord.IUserModel; namespace Discord.WebSocket { [DebuggerDisplay(@"{DebuggerDisplay,nq}")] - internal class SocketGlobalUser : SocketUser + internal class SocketGlobalUser : SocketUser, IDisposable { public override bool IsBot { get; internal set; } public override string Username { get; internal set; } public override ushort DiscriminatorValue { get; internal set; } public override string AvatarId { get; internal set; } - internal override SocketPresence Presence { get; set; } public override bool IsWebhook => false; internal override SocketGlobalUser GlobalUser { get => this; set => throw new NotImplementedException(); } @@ -23,8 +22,9 @@ namespace Discord.WebSocket private SocketGlobalUser(DiscordSocketClient discord, ulong id) : base(discord, id) { + } - internal static SocketGlobalUser Create(DiscordSocketClient discord, ClientState state, Model model) + internal static SocketGlobalUser Create(DiscordSocketClient discord, ClientStateManager state, Model model) { var entity = new SocketGlobalUser(discord, model.Id); entity.Update(state, model); @@ -48,6 +48,9 @@ namespace Discord.WebSocket } } + ~SocketGlobalUser() => Discord.StateManager.RemoveReferencedGlobalUser(Id); + public void Dispose() => Discord.StateManager.RemoveReferencedGlobalUser(Id); + private string DebuggerDisplay => $"{Username}#{Discriminator} ({Id}{(IsBot ? ", Bot" : "")}, Global)"; internal new SocketGlobalUser Clone() => MemberwiseClone() as SocketGlobalUser; } diff --git a/src/Discord.Net.WebSocket/Entities/Users/SocketGroupUser.cs b/src/Discord.Net.WebSocket/Entities/Users/SocketGroupUser.cs index a40ae59be..9d5fb0ef8 100644 --- a/src/Discord.Net.WebSocket/Entities/Users/SocketGroupUser.cs +++ b/src/Discord.Net.WebSocket/Entities/Users/SocketGroupUser.cs @@ -30,7 +30,7 @@ namespace Discord.WebSocket /// public override string AvatarId { get { return GlobalUser.AvatarId; } internal set { GlobalUser.AvatarId = value; } } /// - internal override SocketPresence Presence { get { return GlobalUser.Presence; } set { GlobalUser.Presence = value; } } + internal override Lazy Presence { get { return GlobalUser.Presence; } set { GlobalUser.Presence = value; } } /// public override bool IsWebhook => false; @@ -41,7 +41,7 @@ namespace Discord.WebSocket Channel = channel; GlobalUser = globalUser; } - internal static SocketGroupUser Create(SocketGroupChannel channel, ClientState state, Model model) + internal static SocketGroupUser Create(SocketGroupChannel channel, ClientStateManager state, Model model) { var entity = new SocketGroupUser(channel, channel.Discord.GetOrCreateUser(state, model)); entity.Update(state, model); diff --git a/src/Discord.Net.WebSocket/Entities/Users/SocketGuildUser.cs b/src/Discord.Net.WebSocket/Entities/Users/SocketGuildUser.cs index 051687b78..293013938 100644 --- a/src/Discord.Net.WebSocket/Entities/Users/SocketGuildUser.cs +++ b/src/Discord.Net.WebSocket/Entities/Users/SocketGuildUser.cs @@ -6,9 +6,9 @@ using System.Collections.Immutable; using System.Diagnostics; using System.Linq; using System.Threading.Tasks; -using UserModel = Discord.API.User; -using MemberModel = Discord.API.GuildMember; -using PresenceModel = Discord.API.Presence; +using UserModel = Discord.IUserModel; +using MemberModel = Discord.IMemberModel; +using PresenceModel = Discord.IPresenceModel; namespace Discord.WebSocket { @@ -16,19 +16,24 @@ namespace Discord.WebSocket /// Represents a WebSocket-based guild user. /// [DebuggerDisplay(@"{DebuggerDisplay,nq}")] - public class SocketGuildUser : SocketUser, IGuildUser + public class SocketGuildUser : SocketUser, IGuildUser, ICached, IDisposable { #region SocketGuildUser private long? _premiumSinceTicks; private long? _timedOutTicks; private long? _joinedAtTicks; private ImmutableArray _roleIds; + private ulong _guildId; internal override SocketGlobalUser GlobalUser { get; set; } /// /// Gets the guild the user is in. /// - public SocketGuild Guild { get; } + public Lazy Guild { get; } + /// + /// Gets the guilds id that the user is in. + /// + public ulong GuildId => _guildId; /// public string DisplayName => Nickname ?? Username; /// @@ -47,8 +52,7 @@ namespace Discord.WebSocket public override string AvatarId { get { return GlobalUser.AvatarId; } internal set { GlobalUser.AvatarId = value; } } /// - public GuildPermissions GuildPermissions => new GuildPermissions(Permissions.ResolveGuild(Guild, this)); - internal override SocketPresence Presence { get; set; } + public GuildPermissions GuildPermissions => new GuildPermissions(Permissions.ResolveGuild(Guild.Value, this)); /// public override bool IsWebhook => false; @@ -78,7 +82,7 @@ namespace Discord.WebSocket /// Returns a collection of roles that the user possesses. /// public IReadOnlyCollection Roles - => _roleIds.Select(id => Guild.GetRole(id)).Where(x => x != null).ToReadOnlyCollection(() => _roleIds.Length); + => _roleIds.Select(id => Guild.Value.GetRole(id)).Where(x => x != null).ToReadOnlyCollection(() => _roleIds.Length); /// /// Returns the voice channel the user is in, or null if none. /// @@ -92,8 +96,8 @@ namespace Discord.WebSocket /// A representing the user's voice status; null if the user is not /// connected to a voice channel. /// - public SocketVoiceState? VoiceState => Guild.GetVoiceState(Id); - public AudioInStream AudioStream => Guild.GetAudioStream(Id); + public SocketVoiceState? VoiceState => Guild.Value.GetVoiceState(Id); + public AudioInStream AudioStream => Guild.Value.GetAudioStream(Id); /// public DateTimeOffset? PremiumSince => DateTimeUtils.FromTicks(_premiumSinceTicks); /// @@ -119,13 +123,13 @@ namespace Discord.WebSocket { get { - if (Guild.OwnerId == Id) + if (Guild.Value.OwnerId == Id) return int.MaxValue; int maxPos = 0; for (int i = 0; i < _roleIds.Length; i++) { - var role = Guild.GetRole(_roleIds[i]); + var role = Guild.Value.GetRole(_roleIds[i]); if (role != null && role.Position > maxPos) maxPos = role.Position; } @@ -133,79 +137,46 @@ namespace Discord.WebSocket } } - internal SocketGuildUser(SocketGuild guild, SocketGlobalUser globalUser) - : base(guild.Discord, globalUser.Id) + internal SocketGuildUser(ulong guildId, SocketGlobalUser globalUser, DiscordSocketClient client) + : base(client, globalUser.Id) { - Guild = guild; + _guildId = guildId; + Guild = new Lazy(() => client.StateManager.GetGuild(_guildId), System.Threading.LazyThreadSafetyMode.PublicationOnly); GlobalUser = globalUser; } - internal static SocketGuildUser Create(SocketGuild guild, ClientState state, UserModel model) - { - var entity = new SocketGuildUser(guild, guild.Discord.GetOrCreateUser(state, model)); - entity.Update(state, model); - entity.UpdateRoles(new ulong[0]); - return entity; - } - internal static SocketGuildUser Create(SocketGuild guild, ClientState state, MemberModel model) + internal static SocketGuildUser Create(ulong guildId, DiscordSocketClient client, UserModel model) { - var entity = new SocketGuildUser(guild, guild.Discord.GetOrCreateUser(state, model.User)); - entity.Update(state, model); - if (!model.Roles.IsSpecified) - entity.UpdateRoles(new ulong[0]); + var entity = new SocketGuildUser(guildId, client.GetOrCreateUser(client.StateManager, (Discord.API.User)model), client); + if (entity.Update(client.StateManager, model)) + client.StateManager.AddOrUpdateMember(guildId, entity); + entity.UpdateRoles(Array.Empty()); return entity; } - internal static SocketGuildUser Create(SocketGuild guild, ClientState state, PresenceModel model) + internal static SocketGuildUser Create(ulong guildId, DiscordSocketClient client, MemberModel model) { - var entity = new SocketGuildUser(guild, guild.Discord.GetOrCreateUser(state, model.User)); - entity.Update(state, model, false); - if (!model.Roles.IsSpecified) - entity.UpdateRoles(new ulong[0]); + var entity = new SocketGuildUser(guildId, client.GetOrCreateUser(client.StateManager, model.User), client); + entity.Update(client.StateManager, model); + client.StateManager.AddOrUpdateMember(guildId, entity); return entity; } - internal void Update(ClientState state, MemberModel model) + internal void Update(ClientStateManager state, MemberModel model) { base.Update(state, model.User); - if (model.JoinedAt.IsSpecified) - _joinedAtTicks = model.JoinedAt.Value.UtcTicks; - if (model.Nick.IsSpecified) - Nickname = model.Nick.Value; - if (model.Avatar.IsSpecified) - GuildAvatarId = model.Avatar.Value; - if (model.Roles.IsSpecified) - UpdateRoles(model.Roles.Value); - if (model.PremiumSince.IsSpecified) - _premiumSinceTicks = model.PremiumSince.Value?.UtcTicks; - if (model.TimedOutUntil.IsSpecified) - _timedOutTicks = model.TimedOutUntil.Value?.UtcTicks; - if (model.Pending.IsSpecified) - IsPending = model.Pending.Value; - } - internal void Update(ClientState state, PresenceModel model, bool updatePresence) - { - if (updatePresence) - { - Update(model); - } - if (model.Nick.IsSpecified) - Nickname = model.Nick.Value; - if (model.Roles.IsSpecified) - UpdateRoles(model.Roles.Value); - if (model.PremiumSince.IsSpecified) - _premiumSinceTicks = model.PremiumSince.Value?.UtcTicks; - } - - internal override void Update(PresenceModel model) - { - Presence ??= new SocketPresence(); - Presence.Update(model); - GlobalUser.Update(model); + _joinedAtTicks = model.JoinedAt.UtcTicks; + Nickname = model.Nickname; + GuildAvatarId = model.GuildAvatar; + UpdateRoles(model.Roles); + if (model.PremiumSince.HasValue) + _premiumSinceTicks = model.PremiumSince.Value.UtcTicks; + if (model.CommunicationsDisabledUntil.HasValue) + _timedOutTicks = model.CommunicationsDisabledUntil.Value.UtcTicks; + IsPending = model.IsPending.GetValueOrDefault(false); } - private void UpdateRoles(ulong[] roleIds) { var roles = ImmutableArray.CreateBuilder(roleIds.Length + 1); - roles.Add(Guild.Id); + roles.Add(_guildId); for (int i = 0; i < roleIds.Length; i++) roles.Add(roleIds[i]); _roleIds = roles.ToImmutable(); @@ -249,7 +220,7 @@ namespace Discord.WebSocket => UserHelper.RemoveTimeOutAsync(this, Discord, options); /// public ChannelPermissions GetPermissions(IGuildChannel channel) - => new ChannelPermissions(Permissions.ResolveChannel(Guild, this, channel, GuildPermissions.RawValue)); + => new ChannelPermissions(Permissions.ResolveChannel(Guild.Value, this, channel, GuildPermissions.RawValue)); /// public string GetDisplayAvatarUrl(ImageFormat format = ImageFormat.Auto, ushort size = 128) @@ -259,7 +230,7 @@ namespace Discord.WebSocket /// public string GetGuildAvatarUrl(ImageFormat format = ImageFormat.Auto, ushort size = 128) - => CDN.GetGuildUserAvatarUrl(Id, Guild.Id, GuildAvatarId, size, format); + => CDN.GetGuildUserAvatarUrl(Id, _guildId, GuildAvatarId, size, format); private string DebuggerDisplay => $"{Username}#{Discriminator} ({Id}{(IsBot ? ", Bot" : "")}, Guild)"; @@ -269,13 +240,14 @@ namespace Discord.WebSocket clone.GlobalUser = GlobalUser.Clone(); return clone; } + #endregion #region IGuildUser /// - IGuild IGuildUser.Guild => Guild; + IGuild IGuildUser.Guild => Guild.Value; /// - ulong IGuildUser.GuildId => Guild.Id; + ulong IGuildUser.GuildId => _guildId; /// IReadOnlyCollection IGuildUser.RoleIds => _roleIds; @@ -283,5 +255,55 @@ namespace Discord.WebSocket /// IVoiceChannel IVoiceState.VoiceChannel => VoiceChannel; #endregion + + #region Cache + + private struct CacheModel : MemberModel + { + public UserModel User { get; set; } + + public string Nickname { get; set; } + + public string GuildAvatar { get; set; } + + public ulong[] Roles { get; set; } + + public DateTimeOffset JoinedAt { get; set; } + + public DateTimeOffset? PremiumSince { get; set; } + + public bool IsDeaf { get; set; } + + public bool IsMute { get; set; } + + public bool? IsPending { get; set; } + + public DateTimeOffset? CommunicationsDisabledUntil { get; set; } + } + + MemberModel ICached.ToModel() + => ToMemberModel(); + + internal MemberModel ToMemberModel() + { + return new CacheModel + { + User = ((ICached)this).ToModel(), + CommunicationsDisabledUntil = TimedOutUntil, + GuildAvatar = GuildAvatarId, + IsDeaf = IsDeafened, + IsMute = IsMuted, + IsPending = IsPending, + JoinedAt = JoinedAt ?? DateTimeOffset.UtcNow, // review: nullable joined at here? should our model reflect this? + Nickname = Nickname, + PremiumSince = PremiumSince, + Roles = _roleIds.ToArray() + }; + } + + public void Dispose() => Discord.StateManager.RemovedReferencedMember(Id, _guildId); + ~SocketGuildUser() => Discord.StateManager.RemovedReferencedMember(Id, _guildId); + + #endregion } } diff --git a/src/Discord.Net.WebSocket/Entities/Users/SocketPresence.cs b/src/Discord.Net.WebSocket/Entities/Users/SocketPresence.cs index 5250e15ad..e6cd61bcc 100644 --- a/src/Discord.Net.WebSocket/Entities/Users/SocketPresence.cs +++ b/src/Discord.Net.WebSocket/Entities/Users/SocketPresence.cs @@ -3,7 +3,7 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; using System.Linq; -using Model = Discord.API.Presence; +using Model = Discord.IPresenceModel; namespace Discord.WebSocket { @@ -11,8 +11,11 @@ namespace Discord.WebSocket /// Represents the WebSocket user's presence status. This may include their online status and their activity. /// [DebuggerDisplay(@"{DebuggerDisplay,nq}")] - public class SocketPresence : IPresence + public class SocketPresence : IPresence, ICached { + internal ulong UserId; + internal ulong? GuildId; + /// public UserStatus Status { get; private set; } /// @@ -38,8 +41,10 @@ namespace Discord.WebSocket internal void Update(Model model) { Status = model.Status; - ActiveClients = ConvertClientTypesDict(model.ClientStatus.GetValueOrDefault()) ?? ImmutableArray.Empty; + ActiveClients = model.ActiveClients.Length > 0 ? model.ActiveClients.ToImmutableArray() : ImmutableArray.Empty; Activities = ConvertActivitiesList(model.Activities) ?? ImmutableArray.Empty; + UserId = model.UserId; + GuildId = model.GuildId; } /// @@ -76,9 +81,9 @@ namespace Discord.WebSocket /// /// A list of all that this user currently has available. /// - private static IImmutableList ConvertActivitiesList(IList activities) + private static IImmutableList ConvertActivitiesList(IActivityModel[] activities) { - if (activities == null || activities.Count == 0) + if (activities == null || activities.Length == 0) return ImmutableList.Empty; var list = new List(); foreach (var activity in activities) @@ -96,5 +101,61 @@ namespace Discord.WebSocket private string DebuggerDisplay => $"{Status}{(Activities?.FirstOrDefault()?.Name ?? "")}"; internal SocketPresence Clone() => MemberwiseClone() as SocketPresence; + + #region Cache + private struct CacheModel : Model + { + public UserStatus Status { get; set; } + + public ClientType[] ActiveClients { get; set; } + + public IActivityModel[] Activities { get; set; } + + public ulong UserId { get; set; } + + public ulong? GuildId { get; set; } + } + + internal Model ToModel() + { + return new CacheModel + { + Status = Status, + ActiveClients = ActiveClients.ToArray(), + UserId = UserId, + GuildId = GuildId, + Activities = Activities.Select(x => + { + switch (x) + { + case Game game: + switch (game) + { + case RichGame richGame: + return richGame.ToModel(); + case SpotifyGame spotify: + return spotify.ToModel(); + case CustomStatusGame custom: + return custom.ToModel(); + case StreamingGame stream: + return stream.ToModel(); + } + break; + } + + return new WritableActivityModel + { + Name = x.Name, + Details = x.Details, + Flags = x.Flags, + Type = x.Type + }; + }).ToArray(), + }; + } + + Model ICached.ToModel() => ToModel(); + + #endregion } } diff --git a/src/Discord.Net.WebSocket/Entities/Users/SocketSelfUser.cs b/src/Discord.Net.WebSocket/Entities/Users/SocketSelfUser.cs index 3bde1beab..45b3ebc4f 100644 --- a/src/Discord.Net.WebSocket/Entities/Users/SocketSelfUser.cs +++ b/src/Discord.Net.WebSocket/Entities/Users/SocketSelfUser.cs @@ -2,7 +2,8 @@ using Discord.Rest; using System; using System.Diagnostics; using System.Threading.Tasks; -using Model = Discord.API.User; +using Model = Discord.ICurrentUserModel; +using UserModel = Discord.IUserModel; namespace Discord.WebSocket { @@ -10,7 +11,7 @@ namespace Discord.WebSocket /// Represents the logged-in WebSocket-based user. /// [DebuggerDisplay(@"{DebuggerDisplay,nq}")] - public class SocketSelfUser : SocketUser, ISelfUser + public class SocketSelfUser : SocketUser, ISelfUser, ICached { /// public string Email { get; private set; } @@ -29,7 +30,7 @@ namespace Discord.WebSocket /// public override string AvatarId { get { return GlobalUser.AvatarId; } internal set { GlobalUser.AvatarId = value; } } /// - internal override SocketPresence Presence { get { return GlobalUser.Presence; } set { GlobalUser.Presence = value; } } + internal override Lazy Presence { get { return GlobalUser.Presence; } set { GlobalUser.Presence = value; } } /// public UserProperties Flags { get; internal set; } /// @@ -45,43 +46,47 @@ namespace Discord.WebSocket { GlobalUser = globalUser; } - internal static SocketSelfUser Create(DiscordSocketClient discord, ClientState state, Model model) + internal static SocketSelfUser Create(DiscordSocketClient discord, ClientStateManager state, Model model) { var entity = new SocketSelfUser(discord, discord.GetOrCreateSelfUser(state, model)); entity.Update(state, model); return entity; } - internal override bool Update(ClientState state, Model model) + internal override bool Update(ClientStateManager state, UserModel model) { bool hasGlobalChanges = base.Update(state, model); - if (model.Email.IsSpecified) + + if (model is not Model currentUserModel) + throw new ArgumentException($"Got unexpected model type \"{model?.GetType()}\""); + + if(currentUserModel.Email != Email) { - Email = model.Email.Value; + Email = currentUserModel.Email; hasGlobalChanges = true; } - if (model.Verified.IsSpecified) + if (currentUserModel.IsVerified.HasValue) { - IsVerified = model.Verified.Value; + IsVerified = currentUserModel.IsVerified.Value; hasGlobalChanges = true; } - if (model.MfaEnabled.IsSpecified) + if (currentUserModel.IsMfaEnabled.HasValue) { - IsMfaEnabled = model.MfaEnabled.Value; + IsMfaEnabled = currentUserModel.IsMfaEnabled.Value; hasGlobalChanges = true; } - if (model.Flags.IsSpecified && model.Flags.Value != Flags) + if (currentUserModel.Flags != Flags) { - Flags = (UserProperties)model.Flags.Value; + Flags = currentUserModel.Flags; hasGlobalChanges = true; } - if (model.PremiumType.IsSpecified && model.PremiumType.Value != PremiumType) + if (currentUserModel.PremiumType != PremiumType) { - PremiumType = model.PremiumType.Value; + PremiumType = currentUserModel.PremiumType; hasGlobalChanges = true; } - if (model.Locale.IsSpecified && model.Locale.Value != Locale) + if (currentUserModel.Locale != Locale) { - Locale = model.Locale.Value; + Locale = currentUserModel.Locale; hasGlobalChanges = true; } return hasGlobalChanges; @@ -93,5 +98,55 @@ namespace Discord.WebSocket private string DebuggerDisplay => $"{Username}#{Discriminator} ({Id}{(IsBot ? ", Bot" : "")}, Self)"; internal new SocketSelfUser Clone() => MemberwiseClone() as SocketSelfUser; + + #region Cache + + private struct CacheModel : Model + { + public bool? IsVerified { get; set; } + + public string Email { get; set; } + + public bool? IsMfaEnabled { get; set; } + + public UserProperties Flags { get; set; } + + public PremiumType PremiumType { get; set; } + + public string Locale { get; set; } + + public UserProperties PublicFlags { get; set; } + + public string Username { get; set; } + + public string Discriminator { get; set; } + + public bool? IsBot { get; set; } + + public string Avatar { get; set; } + + public ulong Id { get; set; } + } + + Model ICached.ToModel() + { + return new CacheModel + { + Avatar = AvatarId, + Discriminator = Discriminator, + Email = Email, + Flags = Flags, + Id = Id, + IsBot = IsBot, + IsMfaEnabled = IsMfaEnabled, + IsVerified = IsVerified, + Locale = Locale, + PremiumType = this.PremiumType, + PublicFlags = PublicFlags ?? UserProperties.None, + Username = Username + }; + } + + #endregion } } diff --git a/src/Discord.Net.WebSocket/Entities/Users/SocketThreadUser.cs b/src/Discord.Net.WebSocket/Entities/Users/SocketThreadUser.cs index 6eddd876d..e42805e4e 100644 --- a/src/Discord.Net.WebSocket/Entities/Users/SocketThreadUser.cs +++ b/src/Discord.Net.WebSocket/Entities/Users/SocketThreadUser.cs @@ -227,7 +227,7 @@ namespace Discord.WebSocket internal override SocketGlobalUser GlobalUser { get => GuildUser.GlobalUser; set => GuildUser.GlobalUser = value; } - internal override SocketPresence Presence { get => GuildUser.Presence; set => GuildUser.Presence = value; } + internal override Lazy Presence { get => GuildUser.Presence; set => GuildUser.Presence = value; } /// /// Gets the guild user of this thread user. diff --git a/src/Discord.Net.WebSocket/Entities/Users/SocketUnknownUser.cs b/src/Discord.Net.WebSocket/Entities/Users/SocketUnknownUser.cs index 99c47696a..5d2ddef32 100644 --- a/src/Discord.Net.WebSocket/Entities/Users/SocketUnknownUser.cs +++ b/src/Discord.Net.WebSocket/Entities/Users/SocketUnknownUser.cs @@ -26,7 +26,7 @@ namespace Discord.WebSocket /// public override bool IsWebhook => false; /// - internal override SocketPresence Presence { get { return new SocketPresence(UserStatus.Offline, null, null); } set { } } + internal override Lazy Presence { get { return new Lazy(() => new SocketPresence(UserStatus.Offline, null, null)); } set { } } /// /// This field is not supported for an unknown user. internal override SocketGlobalUser GlobalUser { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } @@ -35,7 +35,7 @@ namespace Discord.WebSocket : base(discord, id) { } - internal static SocketUnknownUser Create(DiscordSocketClient discord, ClientState state, Model model) + internal static SocketUnknownUser Create(DiscordSocketClient discord, ClientStateManager state, Model model) { var entity = new SocketUnknownUser(discord, model.Id); entity.Update(state, model); diff --git a/src/Discord.Net.WebSocket/Entities/Users/SocketUser.cs b/src/Discord.Net.WebSocket/Entities/Users/SocketUser.cs index d70e61739..fca36184b 100644 --- a/src/Discord.Net.WebSocket/Entities/Users/SocketUser.cs +++ b/src/Discord.Net.WebSocket/Entities/Users/SocketUser.cs @@ -6,8 +6,8 @@ using System.Globalization; using System.Linq; using System.Threading.Tasks; using Discord.Rest; -using Model = Discord.API.User; -using PresenceModel = Discord.API.Presence; +using Model = Discord.IUserModel; +using PresenceModel = Discord.IPresenceModel; namespace Discord.WebSocket { @@ -15,7 +15,7 @@ namespace Discord.WebSocket /// Represents a WebSocket-based user. /// [DebuggerDisplay(@"{DebuggerDisplay,nq}")] - public abstract class SocketUser : SocketEntity, IUser + public abstract class SocketUser : SocketEntity, IUser, ICached { /// public abstract bool IsBot { get; internal set; } @@ -30,7 +30,7 @@ namespace Discord.WebSocket /// public UserProperties? PublicFlags { get; private set; } internal abstract SocketGlobalUser GlobalUser { get; set; } - internal abstract SocketPresence Presence { get; set; } + internal virtual Lazy Presence { get; set; } /// public DateTimeOffset CreatedAt => SnowflakeUtils.FromSnowflake(Id); @@ -39,11 +39,11 @@ namespace Discord.WebSocket /// public string Mention => MentionUtils.MentionUser(Id); /// - public UserStatus Status => Presence.Status; + public UserStatus Status => Presence.Value.Status; /// - public IReadOnlyCollection ActiveClients => Presence.ActiveClients ?? ImmutableHashSet.Empty; + public IReadOnlyCollection ActiveClients => Presence.Value.ActiveClients ?? ImmutableHashSet.Empty; /// - public IReadOnlyCollection Activities => Presence.Activities ?? ImmutableList.Empty; + public IReadOnlyCollection Activities => Presence.Value.Activities ?? ImmutableList.Empty; /// /// Gets mutual guilds shared with this user. /// @@ -57,46 +57,45 @@ namespace Discord.WebSocket : base(discord, id) { } - internal virtual bool Update(ClientState state, Model model) + internal virtual bool Update(ClientStateManager state, Model model) { - Presence ??= new SocketPresence(); + Presence ??= new Lazy(() => state.GetPresence(Id), System.Threading.LazyThreadSafetyMode.ExecutionAndPublication); bool hasChanges = false; - if (model.Avatar.IsSpecified && model.Avatar.Value != AvatarId) + if (model.Avatar != AvatarId) { - AvatarId = model.Avatar.Value; + AvatarId = model.Avatar; hasChanges = true; } - if (model.Discriminator.IsSpecified) + if (model.Discriminator != null) { - var newVal = ushort.Parse(model.Discriminator.Value, NumberStyles.None, CultureInfo.InvariantCulture); + var newVal = ushort.Parse(model.Discriminator, NumberStyles.None, CultureInfo.InvariantCulture); if (newVal != DiscriminatorValue) { - DiscriminatorValue = ushort.Parse(model.Discriminator.Value, NumberStyles.None, CultureInfo.InvariantCulture); + DiscriminatorValue = ushort.Parse(model.Discriminator, NumberStyles.None, CultureInfo.InvariantCulture); hasChanges = true; } } - if (model.Bot.IsSpecified && model.Bot.Value != IsBot) + if (model.IsBot.HasValue && model.IsBot.Value != IsBot) { - IsBot = model.Bot.Value; + IsBot = model.IsBot.Value; hasChanges = true; } - if (model.Username.IsSpecified && model.Username.Value != Username) + if (model.Username != Username) { - Username = model.Username.Value; + Username = model.Username; hasChanges = true; } - if (model.PublicFlags.IsSpecified && model.PublicFlags.Value != PublicFlags) + + if(model is ICurrentUserModel currentUserModel) { - PublicFlags = model.PublicFlags.Value; - hasChanges = true; + if (currentUserModel.PublicFlags != PublicFlags) + { + PublicFlags = currentUserModel.PublicFlags; + hasChanges = true; + } } - return hasChanges; - } - internal virtual void Update(PresenceModel model) - { - Presence ??= new SocketPresence(); - Presence.Update(model); + return hasChanges; } /// @@ -120,5 +119,36 @@ namespace Discord.WebSocket public override string ToString() => Format.UsernameAndDiscriminator(this, Discord.FormatUsersInBidirectionalUnicode); private string DebuggerDisplay => $"{Format.UsernameAndDiscriminator(this, Discord.FormatUsersInBidirectionalUnicode)} ({Id}{(IsBot ? ", Bot" : "")})"; internal SocketUser Clone() => MemberwiseClone() as SocketUser; + + #region Cache + private struct CacheModel : Model + { + public string Username { get; set; } + + public string Discriminator { get; set; } + + public bool? IsBot { get; set; } + + public string Avatar { get; set; } + + public ulong Id { get; set; } + } + + Model ICached.ToModel() + => ToModel(); + + internal Model ToModel() + { + return new CacheModel + { + Avatar = AvatarId, + Discriminator = Discriminator, + Id = Id, + IsBot = IsBot, + Username = Username + }; + } + + #endregion } } diff --git a/src/Discord.Net.WebSocket/Entities/Users/SocketWebhookUser.cs b/src/Discord.Net.WebSocket/Entities/Users/SocketWebhookUser.cs index 2b2c259c5..06f9a8ab5 100644 --- a/src/Discord.Net.WebSocket/Entities/Users/SocketWebhookUser.cs +++ b/src/Discord.Net.WebSocket/Entities/Users/SocketWebhookUser.cs @@ -33,7 +33,7 @@ namespace Discord.WebSocket /// public override bool IsWebhook => true; /// - internal override SocketPresence Presence { get { return new SocketPresence(UserStatus.Offline, null, null); } set { } } + internal override Lazy Presence { get { return new Lazy(() => new SocketPresence(UserStatus.Offline, null, null)); } set { } } internal override SocketGlobalUser GlobalUser { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } internal SocketWebhookUser(SocketGuild guild, ulong id, ulong webhookId) @@ -42,7 +42,7 @@ namespace Discord.WebSocket Guild = guild; WebhookId = webhookId; } - internal static SocketWebhookUser Create(SocketGuild guild, ClientState state, Model model, ulong webhookId) + internal static SocketWebhookUser Create(SocketGuild guild, ClientStateManager state, Model model, ulong webhookId) { var entity = new SocketWebhookUser(guild, model.Id, webhookId); entity.Update(state, model); diff --git a/src/Discord.Net.WebSocket/Extensions/EntityExtensions.cs b/src/Discord.Net.WebSocket/Extensions/EntityExtensions.cs index 46f5c1a26..6cde93d87 100644 --- a/src/Discord.Net.WebSocket/Extensions/EntityExtensions.cs +++ b/src/Discord.Net.WebSocket/Extensions/EntityExtensions.cs @@ -7,86 +7,97 @@ namespace Discord.WebSocket { internal static class EntityExtensions { - public static IActivity ToEntity(this API.Game model) + public static IActivity ToEntity(this IActivityModel model) { #region Custom Status Game - if (model.Id.IsSpecified && model.Id.Value == "custom") + if (model.Id != null && model.Id == "custom") { return new CustomStatusGame() { Type = ActivityType.CustomStatus, Name = model.Name, - State = model.State.IsSpecified ? model.State.Value : null, - Emote = model.Emoji.IsSpecified ? model.Emoji.Value.ToIEmote() : null, - CreatedAt = DateTimeOffset.FromUnixTimeMilliseconds(model.CreatedAt.Value), + State = model.State, + Emote = model.Emoji?.ToIEmote(), + CreatedAt = model.CreatedAt, }; } #endregion #region Spotify Game - if (model.SyncId.IsSpecified) + if (model.SyncId != null) { - var assets = model.Assets.GetValueOrDefault()?.ToEntity(); - string albumText = assets?[1]?.Text; - string albumArtId = assets?[1]?.ImageId?.Replace("spotify:", ""); - var timestamps = model.Timestamps.IsSpecified ? model.Timestamps.Value.ToEntity() : null; + string albumText = model.LargeText; + string albumArtId = model.LargeImage?.Replace("spotify:", ""); return new SpotifyGame { Name = model.Name, - SessionId = model.SessionId.GetValueOrDefault(), - TrackId = model.SyncId.Value, - TrackUrl = CDN.GetSpotifyDirectUrl(model.SyncId.Value), + SessionId = model.SessionId, + TrackId = model.SyncId, + TrackUrl = CDN.GetSpotifyDirectUrl(model.SyncId), AlbumTitle = albumText, - TrackTitle = model.Details.GetValueOrDefault(), - Artists = model.State.GetValueOrDefault()?.Split(';').Select(x => x?.Trim()).ToImmutableArray(), - StartedAt = timestamps?.Start, - EndsAt = timestamps?.End, - Duration = timestamps?.End - timestamps?.Start, + TrackTitle = model.Details, + Artists = model.State?.Split(';').Select(x => x?.Trim()).ToImmutableArray(), + StartedAt = model.TimestampStart, + EndsAt = model.TimestampEnd, + Duration = model.TimestampEnd - model.TimestampStart, AlbumArtUrl = albumArtId != null ? CDN.GetSpotifyAlbumArtUrl(albumArtId) : null, Type = ActivityType.Listening, - Flags = model.Flags.GetValueOrDefault(), + Flags = model.Flags, + AlbumArt = model.LargeImage, }; } #endregion #region Rich Game - if (model.ApplicationId.IsSpecified) + if (model.ApplicationId.HasValue) { ulong appId = model.ApplicationId.Value; - var assets = model.Assets.GetValueOrDefault()?.ToEntity(appId); return new RichGame { ApplicationId = appId, Name = model.Name, - Details = model.Details.GetValueOrDefault(), - State = model.State.GetValueOrDefault(), - SmallAsset = assets?[0], - LargeAsset = assets?[1], - Party = model.Party.IsSpecified ? model.Party.Value.ToEntity() : null, - Secrets = model.Secrets.IsSpecified ? model.Secrets.Value.ToEntity() : null, - Timestamps = model.Timestamps.IsSpecified ? model.Timestamps.Value.ToEntity() : null, - Flags = model.Flags.GetValueOrDefault() + Details = model.Details, + State = model.State, + SmallAsset = new GameAsset + { + Text = model.SmallText, + ImageId = model.SmallImage, + ApplicationId = appId, + }, + LargeAsset = new GameAsset + { + Text = model.LargeText, + ApplicationId = appId, + ImageId = model.LargeImage + }, + Party = model.PartyId != null ? new GameParty + { + Id = model.PartyId, + Capacity = model.PartySize?.Length > 1 ? model.PartySize[1] : 0, + Members = model.PartySize?.Length > 0 ? model.PartySize[0] : 0 + } : null, + Secrets = model.JoinSecret != null || model.SpectateSecret != null || model.MatchSecret != null ? new GameSecrets(model.MatchSecret, model.JoinSecret, model.SpectateSecret) : null, + Timestamps = model.TimestampStart.HasValue || model.TimestampEnd.HasValue ? new GameTimestamps(model.TimestampStart, model.TimestampEnd) : null, + Flags = model.Flags }; } #endregion #region Stream Game - if (model.StreamUrl.IsSpecified) + if (model.Url != null) { return new StreamingGame( model.Name, - model.StreamUrl.Value) + model.Url) { - Flags = model.Flags.GetValueOrDefault(), - Details = model.Details.GetValueOrDefault() + Flags = model.Flags, + Details = model.Details }; } #endregion #region Normal Game - return new Game(model.Name, model.Type.GetValueOrDefault() ?? ActivityType.Playing, - model.Flags.IsSpecified ? model.Flags.Value : ActivityProperties.None, - model.Details.GetValueOrDefault()); + return new Game(model.Name, model.Type, model.Flags, model.Details); #endregion } diff --git a/src/Discord.Net.WebSocket/Extensions/StateExtensions.cs b/src/Discord.Net.WebSocket/Extensions/StateExtensions.cs new file mode 100644 index 000000000..7719b26c1 --- /dev/null +++ b/src/Discord.Net.WebSocket/Extensions/StateExtensions.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord.WebSocket +{ + internal static class StateExtensions + { + public static StateBehavior ToBehavior(this CacheMode mode) + { + return mode switch + { + CacheMode.AllowDownload => StateBehavior.AllowDownload, + CacheMode.CacheOnly => StateBehavior.CacheOnly, + _ => StateBehavior.AllowDownload + }; + } + } +} diff --git a/src/Discord.Net.WebSocket/Interactions/ShardedInteractionContext.cs b/src/Discord.Net.WebSocket/Interactions/ShardedInteractionContext.cs index ac0524172..195349861 100644 --- a/src/Discord.Net.WebSocket/Interactions/ShardedInteractionContext.cs +++ b/src/Discord.Net.WebSocket/Interactions/ShardedInteractionContext.cs @@ -19,13 +19,13 @@ namespace Discord.Interactions /// The underlying client. /// The underlying interaction. public ShardedInteractionContext (DiscordShardedClient client, TInteraction interaction) - : base(client.GetShard(GetShardId(client, ( interaction.User as SocketGuildUser )?.Guild)), interaction) + : base(client.GetShard(GetShardId(client, (interaction.User as SocketGuildUser )?.GuildId)), interaction) { Client = client; } - private static int GetShardId (DiscordShardedClient client, IGuild guild) - => guild == null ? 0 : client.GetShardIdFor(guild); + private static int GetShardId(DiscordShardedClient client, ulong? guildId) + => guildId.HasValue ? client.GetShardIdFor(guildId.Value) : 0; } /// diff --git a/src/Discord.Net.WebSocket/Interactions/SocketInteractionContext.cs b/src/Discord.Net.WebSocket/Interactions/SocketInteractionContext.cs index 4cd9ef264..d61068e3e 100644 --- a/src/Discord.Net.WebSocket/Interactions/SocketInteractionContext.cs +++ b/src/Discord.Net.WebSocket/Interactions/SocketInteractionContext.cs @@ -45,7 +45,7 @@ namespace Discord.Interactions { Client = client; Channel = interaction.Channel; - Guild = (interaction.User as SocketGuildUser)?.Guild; + Guild = (interaction.User as SocketGuildUser)?.Guild.Value; User = interaction.User; Interaction = interaction; } diff --git a/src/Discord.Net.WebSocket/State/DefaultStateProvider.cs b/src/Discord.Net.WebSocket/State/DefaultStateProvider.cs new file mode 100644 index 000000000..1604fce0c --- /dev/null +++ b/src/Discord.Net.WebSocket/State/DefaultStateProvider.cs @@ -0,0 +1,256 @@ +using Discord.Logging; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Discord.WebSocket +{ + internal class DefaultStateProvider : IStateProvider + { + private const double AverageChannelsPerGuild = 10.22; //Source: Googie2149 + private const double AverageUsersPerGuild = 47.78; //Source: Googie2149 + private const double CollectionMultiplier = 1.05; //Add 5% buffer to handle growth + + private readonly ICacheProvider _cache; + private readonly StateBehavior _defaultBehavior; + private readonly DiscordSocketClient _client; + private readonly Logger _logger; + public DefaultStateProvider(Logger logger, ICacheProvider cacheProvider, DiscordSocketClient client, StateBehavior stateBehavior) + { + _cache = cacheProvider; + _client = client; + _logger = logger; + + if (stateBehavior == StateBehavior.Default) + throw new ArgumentException("Cannot use \"default\" as the default state behavior"); + + _defaultBehavior = stateBehavior; + } + + private void RunAsyncWithLogs(ValueTask task) + { + _ = Task.Run(async () => + { + try + { + await task.ConfigureAwait(false); + } + catch (Exception x) + { + await _logger.ErrorAsync("Cache provider failed", x).ConfigureAwait(false); + } + }); + } + + private TResult WaitSynchronouslyForTask(Task t) + { + var sw = new SpinWait(); + while (!t.IsCompleted) + sw.SpinOnce(); + return t.GetAwaiter().GetResult(); + } + + private TType ValidateAsSocketEntity(ISnowflakeEntity entity) where TType : SocketEntity + { + if(entity is not TType val) + throw new NotSupportedException("Cannot cache non-socket entities"); + return val; + } + + private StateBehavior ResolveBehavior(StateBehavior behavior) + => behavior == StateBehavior.Default ? _defaultBehavior : behavior; + + + public ValueTask AddOrUpdateMemberAsync(ulong guildId, IGuildUser user) + { + var socketGuildUser = ValidateAsSocketEntity(user); + var model = socketGuildUser.ToMemberModel(); + RunAsyncWithLogs(_cache.AddOrUpdateMemberAsync(model, guildId, CacheRunMode.Async)); + return default; + } + public ValueTask AddOrUpdateUserAsync(IUser user) + { + var socketUser = ValidateAsSocketEntity(user); + var model = socketUser.ToModel(); + RunAsyncWithLogs(_cache.AddOrUpdateUserAsync(model, CacheRunMode.Async)); + return default; + } + public ValueTask GetMemberAsync(ulong guildId, ulong id, StateBehavior stateBehavior, RequestOptions options = null) + { + var behavior = ResolveBehavior(stateBehavior); + + var cacheMode = behavior == StateBehavior.SyncOnly ? CacheRunMode.Sync : CacheRunMode.Async; + + if(behavior != StateBehavior.DownloadOnly) + { + var memberLookupTask = _cache.GetMemberAsync(id, guildId, cacheMode); + + if (memberLookupTask.IsCompleted) + { + var model = memberLookupTask.Result; + if(model != null) + return new ValueTask(SocketGuildUser.Create(guildId, _client, model)); + } + else + { + return new ValueTask(Task.Run(async () => + { + var result = await memberLookupTask; + + if (result != null) + return (IGuildUser)SocketGuildUser.Create(guildId, _client, result); + else if (behavior == StateBehavior.AllowDownload || behavior == StateBehavior.DownloadOnly) + return await _client.Rest.GetGuildUserAsync(guildId, id, options).ConfigureAwait(false); + return null; + })); + } + } + + if (behavior == StateBehavior.AllowDownload || behavior == StateBehavior.DownloadOnly) + return new ValueTask(_client.Rest.GetGuildUserAsync(guildId, id, options).ContinueWith(x => (IGuildUser)x.Result)); + + return default; + } + + public ValueTask> GetMembersAsync(ulong guildId, StateBehavior stateBehavior, RequestOptions options = null) + { + var behavior = ResolveBehavior(stateBehavior); + + var cacheMode = behavior == StateBehavior.SyncOnly ? CacheRunMode.Sync : CacheRunMode.Async; + + if(behavior != StateBehavior.DownloadOnly) + { + var memberLookupTask = _cache.GetMembersAsync(guildId, cacheMode); + + if (memberLookupTask.IsCompleted) + return new ValueTask>(memberLookupTask.Result?.Select(x => SocketGuildUser.Create(guildId, _client, x))); + else + { + return new ValueTask>(Task.Run(async () => + { + var result = await memberLookupTask; + + if (result != null && result.Any()) + return result.Select(x => (IGuildUser)SocketGuildUser.Create(guildId, _client, x)); + + if (behavior == StateBehavior.AllowDownload || behavior == StateBehavior.DownloadOnly) + return await _client.Rest.GetGuildUsersAsync(guildId, options); + + return null; + })); + } + } + + return default; + } + + public ValueTask GetUserAsync(ulong id, StateBehavior stateBehavior, RequestOptions options = null) + { + var behavior = ResolveBehavior(stateBehavior); + + var cacheMode = behavior == StateBehavior.SyncOnly ? CacheRunMode.Sync : CacheRunMode.Async; + + if (behavior != StateBehavior.DownloadOnly) + { + var userLookupTask = _cache.GetUserAsync(id, cacheMode); + + if (userLookupTask.IsCompleted) + { + var model = userLookupTask.Result; + if(model != null) + return new ValueTask(SocketGlobalUser.Create(_client, null, model)); + } + else + { + return new ValueTask(Task.Run(async () => + { + var result = await userLookupTask; + + if (result != null) + return SocketGlobalUser.Create(_client, null, result); + + if (behavior == StateBehavior.AllowDownload || behavior == StateBehavior.DownloadOnly) + return await _client.Rest.GetUserAsync(id, options); + + return null; + })); + } + } + + if (behavior == StateBehavior.AllowDownload || behavior == StateBehavior.DownloadOnly) + return new ValueTask(_client.Rest.GetUserAsync(id, options).ContinueWith(x => (IUser)x.Result)); + + return default; + } + + public ValueTask> GetUsersAsync(StateBehavior stateBehavior, RequestOptions options = null) + { + var behavior = ResolveBehavior(stateBehavior); + + var cacheMode = behavior == StateBehavior.SyncOnly ? CacheRunMode.Sync : CacheRunMode.Async; + + if(behavior != StateBehavior.DownloadOnly) + { + var usersTask = _cache.GetUsersAsync(cacheMode); + + if (usersTask.IsCompleted) + return new ValueTask>(usersTask.Result.Select(x => (IUser)SocketGlobalUser.Create(_client, null, x))); + else + { + return new ValueTask>(usersTask.AsTask().ContinueWith(x => x.Result.Select(x => (IUser)SocketGlobalUser.Create(_client, null, x)))); + } + } + + // no download path + return default; + } + + public ValueTask RemoveMemberAsync(ulong id, ulong guildId) + => _cache.RemoveMemberAsync(id, guildId, CacheRunMode.Async); + public ValueTask RemoveUserAsync(ulong id) + => _cache.RemoveUserAsync(id, CacheRunMode.Async); + + public ValueTask GetPresenceAsync(ulong userId, StateBehavior stateBehavior) + { + var behavior = ResolveBehavior(stateBehavior); + + var cacheMode = behavior == StateBehavior.SyncOnly ? CacheRunMode.Sync : CacheRunMode.Async; + + if(stateBehavior != StateBehavior.DownloadOnly) + { + var fetchTask = _cache.GetPresenceAsync(userId, cacheMode); + + if (fetchTask.IsCompleted) + return new ValueTask(SocketPresence.Create(fetchTask.Result)); + else + { + return new ValueTask(fetchTask.AsTask().ContinueWith(x => + { + if (x.Result != null) + return (IPresence)SocketPresence.Create(x.Result); + return null; + })); + } + } + + // theres no rest call to download presence so return null + return new ValueTask((IPresence)null); + } + + public ValueTask AddOrUpdatePresenseAsync(ulong userId, IPresence presense, StateBehavior stateBehavior) + { + if (presense is not SocketPresence socketPresense) + throw new ArgumentException($"Expected socket entity but got {presense?.GetType()}"); + + var model = socketPresense.ToModel(); + + RunAsyncWithLogs(_cache.AddOrUpdatePresenseAsync(userId, model, CacheRunMode.Async)); + return default; + } + public ValueTask RemovePresenseAsync(ulong userId) + => _cache.RemovePresenseAsync(userId, CacheRunMode.Async); + } +} diff --git a/src/Discord.Net.WebSocket/State/IStateProvider.cs b/src/Discord.Net.WebSocket/State/IStateProvider.cs new file mode 100644 index 000000000..c944d9f19 --- /dev/null +++ b/src/Discord.Net.WebSocket/State/IStateProvider.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord.WebSocket +{ + public interface IStateProvider + { + ValueTask GetPresenceAsync(ulong userId, StateBehavior stateBehavior); + ValueTask AddOrUpdatePresenseAsync(ulong userId, IPresence presense, StateBehavior stateBehavior); + ValueTask RemovePresenseAsync(ulong userId); + + ValueTask GetUserAsync(ulong id, StateBehavior stateBehavior, RequestOptions options = null); + ValueTask> GetUsersAsync(StateBehavior stateBehavior, RequestOptions options = null); + ValueTask AddOrUpdateUserAsync(IUser user); + ValueTask RemoveUserAsync(ulong id); + + ValueTask GetMemberAsync(ulong guildId, ulong id, StateBehavior stateBehavior, RequestOptions options = null); + ValueTask> GetMembersAsync(ulong guildId, StateBehavior stateBehavior, RequestOptions options = null); + ValueTask AddOrUpdateMemberAsync(ulong guildId, IGuildUser user); + ValueTask RemoveMemberAsync(ulong guildId, ulong id); + } +} diff --git a/src/Discord.Net.WebSocket/State/StateBehavior.cs b/src/Discord.Net.WebSocket/State/StateBehavior.cs new file mode 100644 index 000000000..4a387d5a9 --- /dev/null +++ b/src/Discord.Net.WebSocket/State/StateBehavior.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord.WebSocket +{ + public enum StateBehavior + { + /// + /// Use the default Cache Behavior of the client. + /// + /// + Default = 0, + /// + /// The entity will only be retrieved via a synchronous cache lookup. + /// + /// For the default , this is equivalent to using + /// + /// + /// This flag is used to indicate that the retrieval of this entity should not leave the + /// synchronous path of the . When true, + /// the calling method *should* not ever leave the calling task, and never generate an async + /// state machine. + /// + /// Bear in mind that the true behavior of this flag depends entirely on the to + /// abide by design implications of this flag. Once Discord.Net has called out to the state provider with this + /// flag, it is out of our control whether or not an async method is evaluated. + /// + SyncOnly = 1, + /// + /// The entity will only be retrieved via a cache lookup - the Discord API will not be contacted to retrieve the entity. + /// + /// + /// When using an alternative , usage of this flag implies that it is + /// okay for the state provider to make an external call if the local cache missed the entity. + /// + /// Note that when designing an , this flag does not imply that the state + /// provider itself should contact Discord for the entity; rather that if using a dual-layer caching system, + /// it would be okay to contact an external layer, e.g. Redis, for the entity. + /// + CacheOnly = 2, + /// + /// The entity will be downloaded from the Discord REST API if the on hand cannot locate it. + /// + AllowDownload = 3, + /// + /// The entity will be downloaded from the Discord REST API. The local will not be contacted to find the entity. + /// + DownloadOnly = 4 + } +}