Browse Source

custom model factory

v4/state-cache-providers
Quin Lynch 3 years ago
parent
commit
dc2dafa3ac
5 changed files with 116 additions and 32 deletions
  1. +24
    -0
      src/Discord.Net.Core/Cache/Models/Message/IMessageModel.cs
  2. +1
    -16
      src/Discord.Net.WebSocket/Cache/DefaultConcurrentCacheProvider.cs
  3. +32
    -13
      src/Discord.Net.WebSocket/ClientStateManager.Experiment.cs
  4. +3
    -3
      src/Discord.Net.WebSocket/Entities/Users/SocketPresence.cs
  5. +56
    -0
      src/Discord.Net.WebSocket/Extensions/CacheModelExtensions.cs

+ 24
- 0
src/Discord.Net.Core/Cache/Models/Message/IMessageModel.cs View File

@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Discord
{
public interface IMessageModel : IEntityModel<ulong>
{
MessageType Type { get; set; }
ulong ChannelId { get; set; }
ulong? GuildId { get; set; }
ulong AuthorId { get; set; }
bool IsWebhookMessage { get; set; }
string Content { get; set; }
DateTimeOffset Timestamp { get; set; }
DateTimeOffset? EditedTimestamp { get; set; }
bool IsTextToSpeech { get; set; }
bool MentionEveryone { get; set; }
ulong[] UserMentionIds { get; set; }
ulong[] RoleMentionIds { get; set; }
}
}

+ 1
- 16
src/Discord.Net.WebSocket/Cache/DefaultConcurrentCacheProvider.cs View File

@@ -12,16 +12,6 @@ namespace Discord.WebSocket
private readonly ConcurrentDictionary<Type, object> _storeCache = new();
private readonly ConcurrentDictionary<object, object> _subStoreCache = new();

private readonly Dictionary<Type, Type> _models = new()
{
{ typeof(IUserModel), typeof(API.User) },
{ typeof(ICurrentUserModel), typeof(API.CurrentUser) },
{ typeof(IMemberModel), typeof(API.GuildMember) },
{ typeof(IThreadMemberModel), typeof(API.ThreadMember)},
{ typeof(IPresenceModel), typeof(API.Presence)},
{ typeof(IActivityModel), typeof(API.Game)}
};

private class DefaultEntityStore<TModel, TId> : IEntityStore<TModel, TId>
where TModel : IEntityModel<TId>
where TId : IEquatable<TId>
@@ -94,12 +84,7 @@ namespace Discord.WebSocket
}
}

public Type GetModel<TInterface>()
{
if (_models.TryGetValue(typeof(TInterface), out var t))
return t;
return null;
}
public Type GetModel<TInterface>() => null;

public virtual ValueTask<IEntityStore<TModel, TId>> GetStoreAsync<TModel, TId>()
where TModel : IEntityModel<TId>


+ 32
- 13
src/Discord.Net.WebSocket/ClientStateManager.Experiment.cs View File

@@ -57,18 +57,24 @@ namespace Discord.WebSocket
private readonly ConcurrentDictionary<TId, CacheReference<TEntity>> _references = new();
private IEntityStore<TModel, TId> _store;
private Func<TModel, TEntity> _entityBuilder;
private Func<TModel> _modelFactory;
private Func<TId, RequestOptions, Task<TSharedEntity>> _restLookup;
private readonly bool _allowSyncWaits;
private readonly object _lock = new();

public ReferenceStore(ICacheProvider cacheProvider, Func<TModel, TEntity> entityBuilder, Func<TId, RequestOptions, Task<TSharedEntity>> restLookup, bool allowSyncWaits)
public ReferenceStore(ICacheProvider cacheProvider,
Func<TModel, TEntity> entityBuilder,
Func<TId, RequestOptions, Task<TSharedEntity>> restLookup,
Func<TModel> userDefinedModelFactory)
{
_allowSyncWaits = allowSyncWaits;
_cacheProvider = cacheProvider;
_entityBuilder = entityBuilder;
_restLookup = restLookup;
_modelFactory = userDefinedModelFactory;
}

private TModel GetUserDefinedModel(TModel t)
=> t.ToSpecifiedModel(_modelFactory());

internal bool RemoveReference(TId id)
{
if(_references.TryGetValue(id, out var rf))
@@ -196,20 +202,23 @@ namespace Discord.WebSocket

public void AddOrUpdate(TModel model)
{
_store.AddOrUpdate(model);
var userDefinedModel = GetUserDefinedModel(model);
_store.AddOrUpdate(userDefinedModel);
if (TryGetReference(model.Id, out var reference))
reference.Update(model);
reference.Update(userDefinedModel);
}

public ValueTask AddOrUpdateAsync(TModel model)
{
if (TryGetReference(model.Id, out var reference))
reference.Update(model);
return _store.AddOrUpdateAsync(model);
var userDefinedModel = GetUserDefinedModel(model);
if (TryGetReference(userDefinedModel.Id, out var reference))
reference.Update(userDefinedModel);
return _store.AddOrUpdateAsync(userDefinedModel);
}

public void BulkAddOrUpdate(IEnumerable<TModel> models)
{
models = models.Select(x => GetUserDefinedModel(x));
_store.AddOrUpdateBatch(models);
foreach (var model in models)
{
@@ -220,6 +229,7 @@ namespace Discord.WebSocket

public async ValueTask BulkAddOrUpdateAsync(IEnumerable<TModel> models)
{
models = models.Select(x => GetUserDefinedModel(x));
await _store.AddOrUpdateBatchAsync(models).ConfigureAwait(false);

foreach (var model in models)
@@ -274,6 +284,7 @@ namespace Discord.WebSocket
{ typeof(ICurrentUserModel), () => new SocketSelfUser.CacheModel() },
{ typeof(IThreadMemberModel), () => new SocketThreadUser.CacheModel() },
{ typeof(IPresenceModel), () => new SocketPresence.CacheModel() },
{ typeof(IActivityModel), () => new SocketPresence.ActivityCacheModel() }
};


@@ -308,7 +319,7 @@ namespace Discord.WebSocket
_cacheProvider,
m => SocketGuildUser.Create(guildId, _client, m),
async (id, options) => await _client.Rest.GetGuildUserAsync(guildId, id, options).ConfigureAwait(false),
AllowSyncWaits);
GetModel<IMemberModel>);

await store.InitializeAsync(guildId).ConfigureAwait(false);

@@ -334,7 +345,7 @@ namespace Discord.WebSocket
_cacheProvider,
m => SocketThreadUser.Create(_client, guildId, threadId, m),
async (id, options) => await ThreadHelper.GetUserAsync(id, _client.GetChannel(threadId) as SocketThreadChannel, _client, options).ConfigureAwait(false),
AllowSyncWaits);
GetModel<IThreadMemberModel>);

await store.InitializeAsync().ConfigureAwait(false);

@@ -352,6 +363,13 @@ namespace Discord.WebSocket

public TModel GetModel<TModel, TFallback>()
where TFallback : class, TModel, new()
where TModel : class
{
return GetModel<TModel>() ?? new TFallback();
}

public TModel GetModel<TModel>()
where TModel : class
{
var type = _cacheProvider.GetModel<TModel>();

@@ -363,21 +381,22 @@ namespace Discord.WebSocket
return (TModel)Activator.CreateInstance(type);
}
else
return _defaultModelFactory.TryGetValue(typeof(TModel), out var m) ? (TModel)m() : new TFallback();
return _defaultModelFactory.TryGetValue(typeof(TModel), out var m) ? (TModel)m() : null;
}

private void CreateStores()
{
UserStore = new ReferenceStore<SocketGlobalUser, IUserModel, ulong, IUser>(
_cacheProvider,
m => SocketGlobalUser.Create(_client, m),
async (id, options) => await _client.Rest.GetUserAsync(id, options).ConfigureAwait(false),
AllowSyncWaits);
GetModel<IUserModel>);

PresenceStore = new ReferenceStore<SocketPresence, IPresenceModel, ulong, IPresence>(
_cacheProvider,
m => SocketPresence.Create(_client, m),
(id, options) => Task.FromResult<IPresence>(null),
AllowSyncWaits);
GetModel<IPresenceModel>);

_memberStores = new();
_threadMemberStores = new();


+ 3
- 3
src/Discord.Net.WebSocket/Entities/Users/SocketPresence.cs View File

@@ -142,11 +142,11 @@ namespace Discord.WebSocket
ulong IEntityModel<ulong>.Id
{
get => UserId;
set => throw new NotSupportedException();
set => UserId = value;
}
}

private struct ActivityCacheModel : IActivityModel
internal class ActivityCacheModel : IActivityModel
{
public string Id { get; set; }
public string Url { get; set; }
@@ -173,7 +173,7 @@ namespace Discord.WebSocket
public DateTimeOffset? TimestampEnd { get; set; }
}

private struct EmojiCacheModel : IEmojiModel
private class EmojiCacheModel : IEmojiModel
{
public ulong? Id { get; set; }
public string Name { get; set; }


+ 56
- 0
src/Discord.Net.WebSocket/Extensions/CacheModelExtensions.cs View File

@@ -0,0 +1,56 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;

namespace Discord.WebSocket
{
internal static class CacheModelExtensions
{
public static TDest ToSpecifiedModel<TId, TDest>(this IEntityModel<TId> source, TDest dest)
where TId : IEquatable<TId>
where TDest : IEntityModel<TId>
{
if (source == null || dest == null)
throw new ArgumentNullException(source == null ? nameof(source) : nameof(dest));

// get the shared model interface
var sourceType = source.GetType();
var destType = dest.GetType();

if (sourceType == destType)
return (TDest)source;

List<Type> sharedInterfaceModels = new();

foreach (var intf in sourceType.GetInterfaces())
{
if (destType.GetInterface(intf.Name) != null && intf.Name.Contains("Model"))
sharedInterfaceModels.Add(intf);
}

if (sharedInterfaceModels.Count == 0)
throw new NotSupportedException($"cannot find common shared model interface between {sourceType.Name} and {destType.Name}");

foreach (var interfaceType in sharedInterfaceModels)
{
var intfName = interfaceType.GenericTypeArguments.Length == 0 ? interfaceType.FullName :
$"{interfaceType.Namespace}.{Regex.Replace(interfaceType.Name, @"`\d+?$", "")}<{string.Join(", ", interfaceType.GenericTypeArguments.Select(x => x.FullName))}>";

foreach (var prop in interfaceType.GetProperties())
{
var sProp = sourceType.GetProperty($"{intfName}.{prop.Name}", BindingFlags.NonPublic | BindingFlags.Instance) ?? sourceType.GetProperty(prop.Name);
var dProp = destType.GetProperty($"{intfName}.{prop.Name}", BindingFlags.NonPublic | BindingFlags.Instance) ?? destType.GetProperty(prop.Name);

if (sProp == null || dProp == null)
throw new NotSupportedException($"Couldn't find common interface property {prop.Name}");

dProp.SetValue(dest, sProp.GetValue(source));
}
}

return dest;
}
}
}

Loading…
Cancel
Save