|
|
@@ -1,163 +1,342 @@ |
|
|
|
using Discord.Rest; |
|
|
|
using System; |
|
|
|
using System.Collections.Concurrent; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Linq; |
|
|
|
using System.Runtime.CompilerServices; |
|
|
|
using System.Text; |
|
|
|
using System.Threading; |
|
|
|
using System.Threading.Tasks; |
|
|
|
|
|
|
|
namespace Discord.WebSocket |
|
|
|
{ |
|
|
|
internal class CacheWeakReference<T> : WeakReference |
|
|
|
internal class CacheReference<TType> where TType : class |
|
|
|
{ |
|
|
|
public new T Target { get => (T)base.Target; set => base.Target = value; } |
|
|
|
public CacheWeakReference(T target) |
|
|
|
: base(target, false) |
|
|
|
public WeakReference<TType> Reference { get; } |
|
|
|
|
|
|
|
public bool CanRelease |
|
|
|
=> !Reference.TryGetTarget(out _) || _referenceCount <= 0; |
|
|
|
|
|
|
|
private int _referenceCount; |
|
|
|
|
|
|
|
private readonly object _lock = new object(); |
|
|
|
|
|
|
|
public CacheReference(TType value) |
|
|
|
{ |
|
|
|
Reference = new(value); |
|
|
|
_referenceCount = 1; |
|
|
|
} |
|
|
|
|
|
|
|
public bool TryObtainReference(out TType reference) |
|
|
|
{ |
|
|
|
if (Reference.TryGetTarget(out reference)) |
|
|
|
{ |
|
|
|
Interlocked.Increment(ref _referenceCount); |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
public bool TryGetTarget(out T target) |
|
|
|
public void ReleaseReference() |
|
|
|
{ |
|
|
|
target = Target; |
|
|
|
return IsAlive; |
|
|
|
lock (_lock) |
|
|
|
{ |
|
|
|
if (_referenceCount > 0) |
|
|
|
_referenceCount--; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
internal partial class ClientStateManager |
|
|
|
internal class ReferenceStore<TEntity, TModel, TId, ISharedEntity> |
|
|
|
where TEntity : class, ICached<TModel>, ISharedEntity |
|
|
|
where TModel : IEntityModel<TId> |
|
|
|
where TId : IEquatable<TId> |
|
|
|
where ISharedEntity : class |
|
|
|
{ |
|
|
|
private readonly ConcurrentDictionary<ulong, CacheWeakReference<SocketGlobalUser>> _userReferences = new(); |
|
|
|
private readonly ConcurrentDictionary<(ulong GuildId, ulong UserId), CacheWeakReference<SocketGuildUser>> _memberReferences = new(); |
|
|
|
|
|
|
|
|
|
|
|
#region Helpers |
|
|
|
|
|
|
|
private void EnsureSync(ValueTask vt) |
|
|
|
private readonly ICacheProvider _cacheProvider; |
|
|
|
private readonly ConcurrentDictionary<TId, CacheReference<TEntity>> _references = new(); |
|
|
|
private IEntityStore<TModel, TId> _store; |
|
|
|
private Func<TModel, TEntity> _entityBuilder; |
|
|
|
private Func<TId, RequestOptions, Task<ISharedEntity>> _restLookup; |
|
|
|
private readonly bool _allowSyncWaits; |
|
|
|
private readonly object _lock = new(); |
|
|
|
|
|
|
|
public ReferenceStore(ICacheProvider cacheProvider, Func<TModel, TEntity> entityBuilder, Func<TId, RequestOptions, Task<ISharedEntity>> restLookup, bool allowSyncWaits) |
|
|
|
{ |
|
|
|
if (!vt.IsCompleted) |
|
|
|
throw new NotSupportedException($"Cannot use async context for value task lookup"); |
|
|
|
_allowSyncWaits = allowSyncWaits; |
|
|
|
_cacheProvider = cacheProvider; |
|
|
|
_entityBuilder = entityBuilder; |
|
|
|
_restLookup = restLookup; |
|
|
|
} |
|
|
|
|
|
|
|
#endregion |
|
|
|
internal void ClearDeadReferences() |
|
|
|
{ |
|
|
|
lock (_lock) |
|
|
|
{ |
|
|
|
var references = _references.Where(x => x.Value.CanRelease).ToArray(); |
|
|
|
foreach (var reference in references) |
|
|
|
_references.TryRemove(reference.Key, out _); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#region Global users |
|
|
|
internal void RemoveReferencedGlobalUser(ulong id) |
|
|
|
private TResult RunOrThrowValueTask<TResult>(ValueTask<TResult> t) |
|
|
|
{ |
|
|
|
Console.WriteLine("Global user untracked"); |
|
|
|
_userReferences.TryRemove(id, out _); |
|
|
|
if (_allowSyncWaits) |
|
|
|
{ |
|
|
|
return t.GetAwaiter().GetResult(); |
|
|
|
} |
|
|
|
else if (t.IsCompleted) |
|
|
|
return t.Result; |
|
|
|
else |
|
|
|
throw new InvalidOperationException("Cannot run asynchronous value task in synchronous context"); |
|
|
|
} |
|
|
|
|
|
|
|
private void TrackGlobalUser(ulong id, SocketGlobalUser user) |
|
|
|
private void RunOrThrowValueTask(ValueTask t) |
|
|
|
{ |
|
|
|
if (user != null) |
|
|
|
if (_allowSyncWaits) |
|
|
|
{ |
|
|
|
_userReferences.TryAdd(id, new CacheWeakReference<SocketGlobalUser>(user)); |
|
|
|
t.GetAwaiter().GetResult(); |
|
|
|
} |
|
|
|
else if (!t.IsCompleted) |
|
|
|
throw new InvalidOperationException("Cannot run asynchronous value task in synchronous context"); |
|
|
|
} |
|
|
|
|
|
|
|
internal ValueTask<IUser> GetUserAsync(ulong id, CacheMode mode = CacheMode.AllowDownload, RequestOptions options = null) |
|
|
|
=> _state.GetUserAsync(id, mode.ToBehavior(), options); |
|
|
|
public async ValueTask InitializeAsync() |
|
|
|
{ |
|
|
|
_store ??= await _cacheProvider.GetStoreAsync<TModel, TId>().ConfigureAwait(false); |
|
|
|
} |
|
|
|
|
|
|
|
public async ValueTask InitializeAsync(TId parentId) |
|
|
|
{ |
|
|
|
_store ??= await _cacheProvider.GetSubStoreAsync<TModel, TId>(parentId).ConfigureAwait(false); |
|
|
|
} |
|
|
|
|
|
|
|
internal SocketGlobalUser GetUser(ulong id) |
|
|
|
private bool TryGetReference(TId id, out TEntity entity) |
|
|
|
{ |
|
|
|
if (_userReferences.TryGetValue(id, out var userRef) && userRef.TryGetTarget(out var user)) |
|
|
|
return user; |
|
|
|
entity = null; |
|
|
|
return _references.TryGetValue(id, out var reference) && reference.TryObtainReference(out entity); |
|
|
|
} |
|
|
|
|
|
|
|
public TEntity Get(TId id) |
|
|
|
{ |
|
|
|
if(TryGetReference(id, out var entity)) |
|
|
|
{ |
|
|
|
return entity; |
|
|
|
} |
|
|
|
|
|
|
|
user = (SocketGlobalUser)_state.GetUserAsync(id, StateBehavior.SyncOnly).Result; |
|
|
|
var model = RunOrThrowValueTask(_store.GetAsync(id, CacheRunMode.Sync)); |
|
|
|
|
|
|
|
if(user != null) |
|
|
|
TrackGlobalUser(id, user); |
|
|
|
if (model != null) |
|
|
|
{ |
|
|
|
entity = _entityBuilder(model); |
|
|
|
_references.TryAdd(id, new CacheReference<TEntity>(entity)); |
|
|
|
return entity; |
|
|
|
} |
|
|
|
|
|
|
|
return user; |
|
|
|
return null; |
|
|
|
} |
|
|
|
|
|
|
|
internal SocketGlobalUser GetOrAddUser(ulong id, Func<ulong, SocketGlobalUser> userFactory) |
|
|
|
public async ValueTask<ISharedEntity> GetAsync(TId id, CacheMode mode, RequestOptions options = null) |
|
|
|
{ |
|
|
|
if (_userReferences.TryGetValue(id, out var userRef) && userRef.TryGetTarget(out var user)) |
|
|
|
return user; |
|
|
|
if (TryGetReference(id, out var entity)) |
|
|
|
{ |
|
|
|
return entity; |
|
|
|
} |
|
|
|
|
|
|
|
var model = await _store.GetAsync(id, CacheRunMode.Async).ConfigureAwait(false); |
|
|
|
|
|
|
|
user = GetUser(id); |
|
|
|
if (model != null) |
|
|
|
{ |
|
|
|
entity = _entityBuilder(model); |
|
|
|
_references.TryAdd(id, new CacheReference<TEntity>(entity)); |
|
|
|
return entity; |
|
|
|
} |
|
|
|
|
|
|
|
if (user == null) |
|
|
|
if(mode == CacheMode.AllowDownload) |
|
|
|
{ |
|
|
|
user ??= userFactory(id); |
|
|
|
_state.AddOrUpdateUserAsync(user); |
|
|
|
TrackGlobalUser(id, user); |
|
|
|
return await _restLookup(id, options).ConfigureAwait(false); |
|
|
|
} |
|
|
|
|
|
|
|
return user; |
|
|
|
return null; |
|
|
|
} |
|
|
|
|
|
|
|
internal void RemoveUser(ulong id) |
|
|
|
public IEnumerable<TEntity> GetAll() |
|
|
|
{ |
|
|
|
_state.RemoveUserAsync(id); |
|
|
|
var models = RunOrThrowValueTask(_store.GetAllAsync(CacheRunMode.Sync).ToArrayAsync()); |
|
|
|
return models.Select(x => |
|
|
|
{ |
|
|
|
var entity = _entityBuilder(x); |
|
|
|
_references.TryAdd(x.Id, new CacheReference<TEntity>(entity)); |
|
|
|
return entity; |
|
|
|
}); |
|
|
|
} |
|
|
|
#endregion |
|
|
|
|
|
|
|
#region GuildUsers |
|
|
|
private void TrackMember(ulong userId, ulong guildId, SocketGuildUser user) |
|
|
|
public async IAsyncEnumerable<TEntity> GetAllAsync() |
|
|
|
{ |
|
|
|
if(user != null) |
|
|
|
await foreach(var model in _store.GetAllAsync(CacheRunMode.Async)) |
|
|
|
{ |
|
|
|
_memberReferences.TryAdd((guildId, userId), new CacheWeakReference<SocketGuildUser>(user)); |
|
|
|
var entity = _entityBuilder(model); |
|
|
|
_references.TryAdd(model.Id, new CacheReference<TEntity>(entity)); |
|
|
|
yield return entity; |
|
|
|
} |
|
|
|
} |
|
|
|
internal void RemovedReferencedMember(ulong userId, ulong guildId) |
|
|
|
=> _memberReferences.TryRemove((guildId, userId), out _); |
|
|
|
|
|
|
|
internal ValueTask<IGuildUser> GetMemberAsync(ulong userId, ulong guildId, CacheMode mode = CacheMode.AllowDownload, RequestOptions options = null) |
|
|
|
=> _state.GetMemberAsync(guildId, userId, mode.ToBehavior(), options); |
|
|
|
public TEntity GetOrAdd(TId id, Func<TId, TModel> valueFactory) |
|
|
|
{ |
|
|
|
var entity = Get(id); |
|
|
|
if (entity != null) |
|
|
|
return entity; |
|
|
|
|
|
|
|
var model = valueFactory(id); |
|
|
|
AddOrUpdate(model); |
|
|
|
return _entityBuilder(model); |
|
|
|
} |
|
|
|
|
|
|
|
public async ValueTask<TEntity> GetOrAddAsync(TId id, Func<TId, TModel> valueFactory) |
|
|
|
{ |
|
|
|
var entity = await GetAsync(id, CacheMode.CacheOnly).ConfigureAwait(false); |
|
|
|
if (entity != null) |
|
|
|
return (TEntity)entity; |
|
|
|
|
|
|
|
var model = valueFactory(id); |
|
|
|
await AddOrUpdateAsync(model); |
|
|
|
return _entityBuilder(model); |
|
|
|
} |
|
|
|
|
|
|
|
internal SocketGuildUser GetMember(ulong userId, ulong guildId) |
|
|
|
public void AddOrUpdate(TModel model) |
|
|
|
{ |
|
|
|
if (_memberReferences.TryGetValue((guildId, userId), out var memberRef) && memberRef.TryGetTarget(out var member)) |
|
|
|
return member; |
|
|
|
member = (SocketGuildUser)_state.GetMemberAsync(guildId, userId, StateBehavior.SyncOnly).Result; |
|
|
|
if(member != null) |
|
|
|
TrackMember(userId, guildId, member); |
|
|
|
return member; |
|
|
|
RunOrThrowValueTask(_store.AddOrUpdateAsync(model, CacheRunMode.Sync)); |
|
|
|
if (TryGetReference(model.Id, out var reference)) |
|
|
|
reference.Update(model); |
|
|
|
} |
|
|
|
|
|
|
|
internal SocketGuildUser GetOrAddMember(ulong userId, ulong guildId, Func<ulong, ulong, SocketGuildUser> memberFactory) |
|
|
|
public ValueTask AddOrUpdateAsync(TModel model) |
|
|
|
{ |
|
|
|
if (_memberReferences.TryGetValue((guildId, userId), out var memberRef) && memberRef.TryGetTarget(out var member)) |
|
|
|
return member; |
|
|
|
if (TryGetReference(model.Id, out var reference)) |
|
|
|
reference.Update(model); |
|
|
|
return _store.AddOrUpdateAsync(model, CacheRunMode.Async); |
|
|
|
} |
|
|
|
|
|
|
|
member = GetMember(userId, guildId); |
|
|
|
public void Remove(TId id) |
|
|
|
{ |
|
|
|
RunOrThrowValueTask(_store.RemoveAsync(id, CacheRunMode.Sync)); |
|
|
|
_references.TryRemove(id, out _); |
|
|
|
} |
|
|
|
|
|
|
|
if (member == null) |
|
|
|
{ |
|
|
|
member ??= memberFactory(userId, guildId); |
|
|
|
TrackMember(userId, guildId, member); |
|
|
|
Task.Run(async () => await _state.AddOrUpdateMemberAsync(guildId, member)); // can run async, think of this as fire and forget. |
|
|
|
} |
|
|
|
public ValueTask RemoveAsync(TId id) |
|
|
|
{ |
|
|
|
_references.TryRemove(id, out _); |
|
|
|
return _store.RemoveAsync(id, CacheRunMode.Async); |
|
|
|
} |
|
|
|
|
|
|
|
return member; |
|
|
|
public void Purge() |
|
|
|
{ |
|
|
|
RunOrThrowValueTask(_store.PurgeAllAsync(CacheRunMode.Sync)); |
|
|
|
_references.Clear(); |
|
|
|
} |
|
|
|
|
|
|
|
internal IEnumerable<IGuildUser> GetMembers(ulong guildId) |
|
|
|
=> _state.GetMembersAsync(guildId, StateBehavior.SyncOnly).Result; |
|
|
|
public ValueTask PurgeAsync() |
|
|
|
{ |
|
|
|
_references.Clear(); |
|
|
|
return _store.PurgeAllAsync(CacheRunMode.Async); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
internal void AddOrUpdateMember(ulong guildId, SocketGuildUser user) |
|
|
|
=> EnsureSync(_state.AddOrUpdateMemberAsync(guildId, user)); |
|
|
|
internal partial class ClientStateManager |
|
|
|
{ |
|
|
|
public ReferenceStore<SocketGlobalUser, IUserModel, ulong, IUser> UserStore; |
|
|
|
public ReferenceStore<SocketPresence, IPresenceModel, ulong, IPresence> PresenceStore; |
|
|
|
private ConcurrentDictionary<ulong, ReferenceStore<SocketGuildUser, IMemberModel, ulong, IGuildUser>> _memberStores; |
|
|
|
private ConcurrentDictionary<ulong, ReferenceStore<SocketThreadUser, IThreadMemberModel, ulong, IThreadUser>> _threadMemberStores; |
|
|
|
|
|
|
|
internal void RemoveMember(ulong userId, ulong guildId) |
|
|
|
=> EnsureSync(_state.RemoveMemberAsync(guildId, userId)); |
|
|
|
private SemaphoreSlim _memberStoreLock; |
|
|
|
private SemaphoreSlim _threadMemberLock; |
|
|
|
|
|
|
|
#endregion |
|
|
|
private void CreateStores() |
|
|
|
{ |
|
|
|
UserStore = new ReferenceStore<SocketGlobalUser, IUserModel, ulong, IUser>( |
|
|
|
_cacheProvider, |
|
|
|
m => SocketGlobalUser.Create(_client, m), |
|
|
|
async (id, options) => await _client.Rest.GetUserAsync(id, options).ConfigureAwait(false), |
|
|
|
AllowSyncWaits); |
|
|
|
|
|
|
|
PresenceStore = new ReferenceStore<SocketPresence, IPresenceModel, ulong, IPresence>( |
|
|
|
_cacheProvider, |
|
|
|
m => SocketPresence.Create(m), |
|
|
|
(id, options) => Task.FromResult<IPresence>(null), |
|
|
|
AllowSyncWaits); |
|
|
|
|
|
|
|
_memberStores = new(); |
|
|
|
_threadMemberStores = new(); |
|
|
|
|
|
|
|
_threadMemberLock = new(1, 1); |
|
|
|
_memberStoreLock = new(1,1); |
|
|
|
} |
|
|
|
|
|
|
|
#region Presence |
|
|
|
internal void AddOrUpdatePresence(SocketPresence presence) |
|
|
|
public void ClearDeadReferences() |
|
|
|
{ |
|
|
|
EnsureSync(_state.AddOrUpdatePresenseAsync(presence.UserId, presence, StateBehavior.SyncOnly)); |
|
|
|
UserStore.ClearDeadReferences(); |
|
|
|
PresenceStore.ClearDeadReferences(); |
|
|
|
} |
|
|
|
|
|
|
|
internal SocketPresence GetPresence(ulong userId) |
|
|
|
public async ValueTask InitializeAsync() |
|
|
|
{ |
|
|
|
if (_state.GetPresenceAsync(userId, StateBehavior.SyncOnly).Result is not SocketPresence socketPresence) |
|
|
|
throw new NotSupportedException("Cannot use non-socket entity for presence"); |
|
|
|
await UserStore.InitializeAsync(); |
|
|
|
await PresenceStore.InitializeAsync(); |
|
|
|
} |
|
|
|
|
|
|
|
public bool TryGetMemberStore(ulong guildId, out ReferenceStore<SocketGuildUser, IMemberModel, ulong, IGuildUser> store) |
|
|
|
=> _memberStores.TryGetValue(guildId, out store); |
|
|
|
|
|
|
|
return socketPresence; |
|
|
|
public async ValueTask<ReferenceStore<SocketGuildUser, IMemberModel, ulong, IGuildUser>> GetMemberStoreAsync(ulong guildId) |
|
|
|
{ |
|
|
|
if (_memberStores.TryGetValue(guildId, out var store)) |
|
|
|
return store; |
|
|
|
|
|
|
|
await _memberStoreLock.WaitAsync().ConfigureAwait(false); |
|
|
|
|
|
|
|
try |
|
|
|
{ |
|
|
|
store = new ReferenceStore<SocketGuildUser, IMemberModel, ulong, IGuildUser>( |
|
|
|
_cacheProvider, |
|
|
|
m => SocketGuildUser.Create(guildId, _client, m), |
|
|
|
async (id, options) => await _client.Rest.GetGuildUserAsync(guildId, id, options).ConfigureAwait(false), |
|
|
|
AllowSyncWaits); |
|
|
|
|
|
|
|
await store.InitializeAsync(guildId).ConfigureAwait(false); |
|
|
|
|
|
|
|
_memberStores.TryAdd(guildId, store); |
|
|
|
return store; |
|
|
|
} |
|
|
|
finally |
|
|
|
{ |
|
|
|
_memberStoreLock.Release(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public async Task<ReferenceStore<SocketThreadUser, IThreadMemberModel, ulong, IThreadUser>> GetThreadMemberStoreAsync(ulong threadId, ulong guildId) |
|
|
|
{ |
|
|
|
if (_threadMemberStores.TryGetValue(threadId, out var store)) |
|
|
|
return store; |
|
|
|
|
|
|
|
await _threadMemberLock.WaitAsync().ConfigureAwait(false); |
|
|
|
|
|
|
|
try |
|
|
|
{ |
|
|
|
store = new ReferenceStore<SocketThreadUser, IThreadMemberModel, ulong, IThreadUser>( |
|
|
|
_cacheProvider, |
|
|
|
m => SocketThreadUser.Create(_client, guildId, threadId, m), |
|
|
|
async (id, options) => await ThreadHelper.GetUserAsync(id, _client.GetChannel(threadId) as SocketThreadChannel, _client, options).ConfigureAwait(false), |
|
|
|
AllowSyncWaits); |
|
|
|
|
|
|
|
await store.InitializeAsync().ConfigureAwait(false); |
|
|
|
|
|
|
|
_threadMemberStores.TryAdd(threadId, store); |
|
|
|
return store; |
|
|
|
} |
|
|
|
finally |
|
|
|
{ |
|
|
|
_threadMemberLock.Release(); |
|
|
|
} |
|
|
|
} |
|
|
|
#endregion |
|
|
|
} |
|
|
|
} |