diff --git a/src/Discord.Net.Core/Cache/Models/Message/IMessageModel.cs b/src/Discord.Net.Core/Cache/Models/Message/IMessageModel.cs new file mode 100644 index 000000000..c1dc6001c --- /dev/null +++ b/src/Discord.Net.Core/Cache/Models/Message/IMessageModel.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Discord +{ + public interface IMessageModel : IEntityModel + { + MessageType Type { get; set; } + ulong ChannelId { get; set; } + ulong? GuildId { get; set; } + ulong AuthorId { get; set; } + bool IsWebhookMessage { get; set; } + string Content { get; set; } + DateTimeOffset Timestamp { get; set; } + DateTimeOffset? EditedTimestamp { get; set; } + bool IsTextToSpeech { get; set; } + bool MentionEveryone { get; set; } + ulong[] UserMentionIds { get; set; } + ulong[] RoleMentionIds { get; set; } + } +} diff --git a/src/Discord.Net.WebSocket/Cache/DefaultConcurrentCacheProvider.cs b/src/Discord.Net.WebSocket/Cache/DefaultConcurrentCacheProvider.cs index 7a146677b..f182bd1f6 100644 --- a/src/Discord.Net.WebSocket/Cache/DefaultConcurrentCacheProvider.cs +++ b/src/Discord.Net.WebSocket/Cache/DefaultConcurrentCacheProvider.cs @@ -12,16 +12,6 @@ namespace Discord.WebSocket private readonly ConcurrentDictionary _storeCache = new(); private readonly ConcurrentDictionary _subStoreCache = new(); - private readonly Dictionary _models = new() - { - { typeof(IUserModel), typeof(API.User) }, - { typeof(ICurrentUserModel), typeof(API.CurrentUser) }, - { typeof(IMemberModel), typeof(API.GuildMember) }, - { typeof(IThreadMemberModel), typeof(API.ThreadMember)}, - { typeof(IPresenceModel), typeof(API.Presence)}, - { typeof(IActivityModel), typeof(API.Game)} - }; - private class DefaultEntityStore : IEntityStore where TModel : IEntityModel where TId : IEquatable @@ -94,12 +84,7 @@ namespace Discord.WebSocket } } - public Type GetModel() - { - if (_models.TryGetValue(typeof(TInterface), out var t)) - return t; - return null; - } + public Type GetModel() => null; public virtual ValueTask> GetStoreAsync() where TModel : IEntityModel diff --git a/src/Discord.Net.WebSocket/ClientStateManager.Experiment.cs b/src/Discord.Net.WebSocket/ClientStateManager.Experiment.cs index dc63dc797..b4b91ff69 100644 --- a/src/Discord.Net.WebSocket/ClientStateManager.Experiment.cs +++ b/src/Discord.Net.WebSocket/ClientStateManager.Experiment.cs @@ -57,18 +57,24 @@ namespace Discord.WebSocket private readonly ConcurrentDictionary> _references = new(); private IEntityStore _store; private Func _entityBuilder; + private Func _modelFactory; private Func> _restLookup; - private readonly bool _allowSyncWaits; private readonly object _lock = new(); - public ReferenceStore(ICacheProvider cacheProvider, Func entityBuilder, Func> restLookup, bool allowSyncWaits) + public ReferenceStore(ICacheProvider cacheProvider, + Func entityBuilder, + Func> restLookup, + Func userDefinedModelFactory) { - _allowSyncWaits = allowSyncWaits; _cacheProvider = cacheProvider; _entityBuilder = entityBuilder; _restLookup = restLookup; + _modelFactory = userDefinedModelFactory; } + private TModel GetUserDefinedModel(TModel t) + => t.ToSpecifiedModel(_modelFactory()); + internal bool RemoveReference(TId id) { if(_references.TryGetValue(id, out var rf)) @@ -196,20 +202,23 @@ namespace Discord.WebSocket public void AddOrUpdate(TModel model) { - _store.AddOrUpdate(model); + var userDefinedModel = GetUserDefinedModel(model); + _store.AddOrUpdate(userDefinedModel); if (TryGetReference(model.Id, out var reference)) - reference.Update(model); + reference.Update(userDefinedModel); } public ValueTask AddOrUpdateAsync(TModel model) { - if (TryGetReference(model.Id, out var reference)) - reference.Update(model); - return _store.AddOrUpdateAsync(model); + var userDefinedModel = GetUserDefinedModel(model); + if (TryGetReference(userDefinedModel.Id, out var reference)) + reference.Update(userDefinedModel); + return _store.AddOrUpdateAsync(userDefinedModel); } public void BulkAddOrUpdate(IEnumerable models) { + models = models.Select(x => GetUserDefinedModel(x)); _store.AddOrUpdateBatch(models); foreach (var model in models) { @@ -220,6 +229,7 @@ namespace Discord.WebSocket public async ValueTask BulkAddOrUpdateAsync(IEnumerable models) { + models = models.Select(x => GetUserDefinedModel(x)); await _store.AddOrUpdateBatchAsync(models).ConfigureAwait(false); foreach (var model in models) @@ -274,6 +284,7 @@ namespace Discord.WebSocket { typeof(ICurrentUserModel), () => new SocketSelfUser.CacheModel() }, { typeof(IThreadMemberModel), () => new SocketThreadUser.CacheModel() }, { typeof(IPresenceModel), () => new SocketPresence.CacheModel() }, + { typeof(IActivityModel), () => new SocketPresence.ActivityCacheModel() } }; @@ -308,7 +319,7 @@ namespace Discord.WebSocket _cacheProvider, m => SocketGuildUser.Create(guildId, _client, m), async (id, options) => await _client.Rest.GetGuildUserAsync(guildId, id, options).ConfigureAwait(false), - AllowSyncWaits); + GetModel); await store.InitializeAsync(guildId).ConfigureAwait(false); @@ -334,7 +345,7 @@ namespace Discord.WebSocket _cacheProvider, m => SocketThreadUser.Create(_client, guildId, threadId, m), async (id, options) => await ThreadHelper.GetUserAsync(id, _client.GetChannel(threadId) as SocketThreadChannel, _client, options).ConfigureAwait(false), - AllowSyncWaits); + GetModel); await store.InitializeAsync().ConfigureAwait(false); @@ -352,6 +363,13 @@ namespace Discord.WebSocket public TModel GetModel() where TFallback : class, TModel, new() + where TModel : class + { + return GetModel() ?? new TFallback(); + } + + public TModel GetModel() + where TModel : class { var type = _cacheProvider.GetModel(); @@ -363,21 +381,22 @@ namespace Discord.WebSocket return (TModel)Activator.CreateInstance(type); } else - return _defaultModelFactory.TryGetValue(typeof(TModel), out var m) ? (TModel)m() : new TFallback(); + return _defaultModelFactory.TryGetValue(typeof(TModel), out var m) ? (TModel)m() : null; } + private void CreateStores() { UserStore = new ReferenceStore( _cacheProvider, m => SocketGlobalUser.Create(_client, m), async (id, options) => await _client.Rest.GetUserAsync(id, options).ConfigureAwait(false), - AllowSyncWaits); + GetModel); PresenceStore = new ReferenceStore( _cacheProvider, m => SocketPresence.Create(_client, m), (id, options) => Task.FromResult(null), - AllowSyncWaits); + GetModel); _memberStores = new(); _threadMemberStores = new(); diff --git a/src/Discord.Net.WebSocket/Entities/Users/SocketPresence.cs b/src/Discord.Net.WebSocket/Entities/Users/SocketPresence.cs index b6e9cd764..220e85305 100644 --- a/src/Discord.Net.WebSocket/Entities/Users/SocketPresence.cs +++ b/src/Discord.Net.WebSocket/Entities/Users/SocketPresence.cs @@ -142,11 +142,11 @@ namespace Discord.WebSocket ulong IEntityModel.Id { get => UserId; - set => throw new NotSupportedException(); + set => UserId = value; } } - private struct ActivityCacheModel : IActivityModel + internal class ActivityCacheModel : IActivityModel { public string Id { get; set; } public string Url { get; set; } @@ -173,7 +173,7 @@ namespace Discord.WebSocket public DateTimeOffset? TimestampEnd { get; set; } } - private struct EmojiCacheModel : IEmojiModel + private class EmojiCacheModel : IEmojiModel { public ulong? Id { get; set; } public string Name { get; set; } diff --git a/src/Discord.Net.WebSocket/Extensions/CacheModelExtensions.cs b/src/Discord.Net.WebSocket/Extensions/CacheModelExtensions.cs new file mode 100644 index 000000000..5b1da074c --- /dev/null +++ b/src/Discord.Net.WebSocket/Extensions/CacheModelExtensions.cs @@ -0,0 +1,56 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text.RegularExpressions; + +namespace Discord.WebSocket +{ + internal static class CacheModelExtensions + { + public static TDest ToSpecifiedModel(this IEntityModel source, TDest dest) + where TId : IEquatable + where TDest : IEntityModel + { + if (source == null || dest == null) + throw new ArgumentNullException(source == null ? nameof(source) : nameof(dest)); + + // get the shared model interface + var sourceType = source.GetType(); + var destType = dest.GetType(); + + if (sourceType == destType) + return (TDest)source; + + List sharedInterfaceModels = new(); + + foreach (var intf in sourceType.GetInterfaces()) + { + if (destType.GetInterface(intf.Name) != null && intf.Name.Contains("Model")) + sharedInterfaceModels.Add(intf); + } + + if (sharedInterfaceModels.Count == 0) + throw new NotSupportedException($"cannot find common shared model interface between {sourceType.Name} and {destType.Name}"); + + foreach (var interfaceType in sharedInterfaceModels) + { + var intfName = interfaceType.GenericTypeArguments.Length == 0 ? interfaceType.FullName : + $"{interfaceType.Namespace}.{Regex.Replace(interfaceType.Name, @"`\d+?$", "")}<{string.Join(", ", interfaceType.GenericTypeArguments.Select(x => x.FullName))}>"; + + foreach (var prop in interfaceType.GetProperties()) + { + var sProp = sourceType.GetProperty($"{intfName}.{prop.Name}", BindingFlags.NonPublic | BindingFlags.Instance) ?? sourceType.GetProperty(prop.Name); + var dProp = destType.GetProperty($"{intfName}.{prop.Name}", BindingFlags.NonPublic | BindingFlags.Instance) ?? destType.GetProperty(prop.Name); + + if (sProp == null || dProp == null) + throw new NotSupportedException($"Couldn't find common interface property {prop.Name}"); + + dProp.SetValue(dest, sProp.GetValue(source)); + } + } + + return dest; + } + } +}