@@ -57,18 +57,24 @@ namespace Discord.WebSocket
private readonly ConcurrentDictionary<TId, CacheReference<TEntity>> _references = new();
private IEntityStore<TModel, TId> _store;
private Func<TModel, TEntity> _entityBuilder;
private Func<TModel> _modelFactory;
private Func<TId, RequestOptions, Task<TSharedEntity>> _restLookup;
private readonly bool _allowSyncWaits;
private readonly object _lock = new();
public ReferenceStore(ICacheProvider cacheProvider, Func<TModel, TEntity> entityBuilder, Func<TId, RequestOptions, Task<TSharedEntity>> restLookup, bool allowSyncWaits)
public ReferenceStore(ICacheProvider cacheProvider,
Func<TModel, TEntity> entityBuilder,
Func<TId, RequestOptions, Task<TSharedEntity>> restLookup,
Func<TModel> userDefinedModelFactory)
{
_allowSyncWaits = allowSyncWaits;
_cacheProvider = cacheProvider;
_entityBuilder = entityBuilder;
_restLookup = restLookup;
_modelFactory = userDefinedModelFactory;
}
private TModel GetUserDefinedModel(TModel t)
=> t.ToSpecifiedModel(_modelFactory());
internal bool RemoveReference(TId id)
{
if(_references.TryGetValue(id, out var rf))
@@ -196,20 +202,23 @@ namespace Discord.WebSocket
public void AddOrUpdate(TModel model)
{
_store.AddOrUpdate(model);
var userDefinedModel = GetUserDefinedModel(model);
_store.AddOrUpdate(userDefinedModel);
if (TryGetReference(model.Id, out var reference))
reference.Update(m odel);
reference.Update(userDefinedM odel);
}
public ValueTask AddOrUpdateAsync(TModel model)
{
if (TryGetReference(model.Id, out var reference))
reference.Update(model);
return _store.AddOrUpdateAsync(model);
var userDefinedModel = GetUserDefinedModel(model);
if (TryGetReference(userDefinedModel.Id, out var reference))
reference.Update(userDefinedModel);
return _store.AddOrUpdateAsync(userDefinedModel);
}
public void BulkAddOrUpdate(IEnumerable<TModel> models)
{
models = models.Select(x => GetUserDefinedModel(x));
_store.AddOrUpdateBatch(models);
foreach (var model in models)
{
@@ -220,6 +229,7 @@ namespace Discord.WebSocket
public async ValueTask BulkAddOrUpdateAsync(IEnumerable<TModel> models)
{
models = models.Select(x => GetUserDefinedModel(x));
await _store.AddOrUpdateBatchAsync(models).ConfigureAwait(false);
foreach (var model in models)
@@ -274,6 +284,7 @@ namespace Discord.WebSocket
{ typeof(ICurrentUserModel), () => new SocketSelfUser.CacheModel() },
{ typeof(IThreadMemberModel), () => new SocketThreadUser.CacheModel() },
{ typeof(IPresenceModel), () => new SocketPresence.CacheModel() },
{ typeof(IActivityModel), () => new SocketPresence.ActivityCacheModel() }
};
@@ -308,7 +319,7 @@ namespace Discord.WebSocket
_cacheProvider,
m => SocketGuildUser.Create(guildId, _client, m),
async (id, options) => await _client.Rest.GetGuildUserAsync(guildId, id, options).ConfigureAwait(false),
AllowSyncWaits );
GetModel<IMemberModel> );
await store.InitializeAsync(guildId).ConfigureAwait(false);
@@ -334,7 +345,7 @@ namespace Discord.WebSocket
_cacheProvider,
m => SocketThreadUser.Create(_client, guildId, threadId, m),
async (id, options) => await ThreadHelper.GetUserAsync(id, _client.GetChannel(threadId) as SocketThreadChannel, _client, options).ConfigureAwait(false),
AllowSyncWaits );
GetModel<IThreadMemberModel> );
await store.InitializeAsync().ConfigureAwait(false);
@@ -352,6 +363,13 @@ namespace Discord.WebSocket
public TModel GetModel<TModel, TFallback>()
where TFallback : class, TModel, new()
where TModel : class
{
return GetModel<TModel>() ?? new TFallback();
}
public TModel GetModel<TModel>()
where TModel : class
{
var type = _cacheProvider.GetModel<TModel>();
@@ -363,21 +381,22 @@ namespace Discord.WebSocket
return (TModel)Activator.CreateInstance(type);
}
else
return _defaultModelFactory.TryGetValue(typeof(TModel), out var m) ? (TModel)m() : new TFallback() ;
return _defaultModelFactory.TryGetValue(typeof(TModel), out var m) ? (TModel)m() : null ;
}
private void CreateStores()
{
UserStore = new ReferenceStore<SocketGlobalUser, IUserModel, ulong, IUser>(
_cacheProvider,
m => SocketGlobalUser.Create(_client, m),
async (id, options) => await _client.Rest.GetUserAsync(id, options).ConfigureAwait(false),
AllowSyncWaits );
GetModel<IUserModel> );
PresenceStore = new ReferenceStore<SocketPresence, IPresenceModel, ulong, IPresence>(
_cacheProvider,
m => SocketPresence.Create(_client, m),
(id, options) => Task.FromResult<IPresence>(null),
AllowSyncWaits );
GetModel<IPresenceModel> );
_memberStores = new();
_threadMemberStores = new();