@@ -92,28 +92,6 @@ namespace Discord.WebSocket
}
}
private TResult RunOrThrowValueTask<TResult>(ValueTask<TResult> t)
{
if (_allowSyncWaits)
{
return t.GetAwaiter().GetResult();
}
else if (t.IsCompleted)
return t.Result;
else
throw new InvalidOperationException("Cannot run asynchronous value task in synchronous context");
}
private void RunOrThrowValueTask(ValueTask t)
{
if (_allowSyncWaits)
{
t.GetAwaiter().GetResult();
}
else if (!t.IsCompleted)
throw new InvalidOperationException("Cannot run asynchronous value task in synchronous context");
}
public async ValueTask InitializeAsync()
{
_store ??= await _cacheProvider.GetStoreAsync<TModel, TId>().ConfigureAwait(false);
@@ -137,7 +115,7 @@ namespace Discord.WebSocket
return entity;
}
var model = RunOrThrowValueTask( _store.GetAsync (id, CacheRunMode.Sync) );
var model = _store.Get(id);
if (model != null)
{
@@ -156,7 +134,7 @@ namespace Discord.WebSocket
return entity;
}
var model = await _store.GetAsync(id, CacheRunMode.Async ).ConfigureAwait(false);
var model = await _store.GetAsync(id).ConfigureAwait(false);
if (model != null)
{
@@ -175,7 +153,7 @@ namespace Discord.WebSocket
public IEnumerable<TEntity> GetAll()
{
var models = RunOrThrowValueTask( _store.GetAllAsync (CacheRunMode.Sync).ToArrayAsync() );
var models = _store.GetAll();
return models.Select(x =>
{
var entity = _entityBuilder(x);
@@ -186,7 +164,7 @@ namespace Discord.WebSocket
public async IAsyncEnumerable<TEntity> GetAllAsync()
{
await foreach(var model in _store.GetAllAsync(CacheRunMode.Async ))
await foreach(var model in _store.GetAllAsync())
{
var entity = _entityBuilder(model);
_references.TryAdd(model.Id, new CacheReference<TEntity>(entity));
@@ -212,13 +190,13 @@ namespace Discord.WebSocket
return (TEntity)entity;
var model = valueFactory(id);
await AddOrUpdateAsync(model);
await AddOrUpdateAsync(model).ConfigureAwait(false) ;
return _entityBuilder(model);
}
public void AddOrUpdate(TModel model)
{
RunOrThrowValueTask( _store.AddOrUpdateAsync (model, CacheRunMode.Sync) );
_store.AddOrUpdate(model);
if (TryGetReference(model.Id, out var reference))
reference.Update(model);
}
@@ -227,14 +205,13 @@ namespace Discord.WebSocket
{
if (TryGetReference(model.Id, out var reference))
reference.Update(model);
return _store.AddOrUpdateAsync(model, CacheRunMode.Async );
return _store.AddOrUpdateAsync(model);
}
public void BulkAddOrUpdate(IEnumerable<TModel> models)
{
RunOrThrowValueTask(_store.AddOrUpdateBatchAsync(models, CacheRunMode.Sync));
foreach(var model in models)
_store.AddOrUpdateBatch(models);
foreach (var model in models)
{
if (_references.TryGetValue(model.Id, out var rf) && rf.Reference.TryGetTarget(out var entity))
entity.Update(model);
@@ -243,7 +220,7 @@ namespace Discord.WebSocket
public async ValueTask BulkAddOrUpdateAsync(IEnumerable<TModel> models)
{
await _store.AddOrUpdateBatchAsync(models, CacheRunMode.Async ).ConfigureAwait(false);
await _store.AddOrUpdateBatchAsync(models).ConfigureAwait(false);
foreach (var model in models)
{
@@ -254,26 +231,26 @@ namespace Discord.WebSocket
public void Remove(TId id)
{
RunOrThrowValueTask( _store.RemoveAsync (id, CacheRunMode.Sync) );
_store.Remove(id);
_references.TryRemove(id, out _);
}
public ValueTask RemoveAsync(TId id)
{
_references.TryRemove(id, out _);
return _store.RemoveAsync(id, CacheRunMode.Async );
return _store.RemoveAsync(id);
}
public void Purge()
{
RunOrThrowValueTask( _store.PurgeAllAsync (CacheRunMode.Sync) );
_store.PurgeAll();
_references.Clear();
}
public ValueTask PurgeAsync()
{
_references.Clear();
return _store.PurgeAllAsync(CacheRunMode.Async );
return _store.PurgeAllAsync();
}
TEntity ILookupReferenceStore<TEntity, TId>.Get(TId id) => Get(id);
@@ -380,5 +357,24 @@ namespace Discord.WebSocket
_threadMemberLock.Release();
}
}
public ReferenceStore<SocketThreadUser, IThreadMemberModel, ulong, IThreadUser> GetThreadMemberStore(ulong threadId)
=> _threadMemberStores.TryGetValue(threadId, out var store) ? store : null;
public TModel GetModel<TModel, TFallback>()
where TFallback : class, TModel, new()
{
var type = _cacheProvider.GetModel<TModel>();
if (type != null)
{
if (!type.GetInterfaces().Contains(typeof(TModel)))
throw new InvalidOperationException($"Cannot use {type.Name} as a model for {typeof(TModel).Name}");
return (TModel)Activator.CreateInstance(type);
}
else
return new TFallback();
}
}
}