You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ConcurrentHashSet.cs 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Collections.ObjectModel;
  5. using System.Diagnostics;
  6. using System.Threading;
  7. namespace Discord
  8. {
  9. //Based on https://github.com/dotnet/corefx/blob/d0dc5fc099946adc1035b34a8b1f6042eddb0c75/src/System.Threading.Tasks.Parallel/src/System/Threading/PlatformHelper.cs
  10. //Copyright (c) .NET Foundation and Contributors
  internal static class ConcurrentHashSet
  {
      private const int PROCESSOR_COUNT_REFRESH_INTERVAL_MS = 30000;
      private static volatile int s_processorCount;
      private static volatile int s_lastProcessorCountRefreshTicks;

      /// <summary>
      /// Gets the machine's processor count, used as the default number of lock stripes.
      /// The value is cached and re-read from <see cref="Environment.ProcessorCount"/> when it has
      /// never been sampled or when at least <c>PROCESSOR_COUNT_REFRESH_INTERVAL_MS</c> milliseconds
      /// (measured with <see cref="Environment.TickCount"/>) have passed since the last sample.
      /// </summary>
      /// <remarks>
      /// No lock is taken: concurrent callers may each refresh the cache. Both cached fields are volatile.
      /// </remarks>
      public static int DefaultConcurrencyLevel
      {
          get
          {
              int nowTicks = Environment.TickCount;

              // Fast path: a cached value exists and is still fresh.
              if (s_processorCount != 0 && (nowTicks - s_lastProcessorCountRefreshTicks) < PROCESSOR_COUNT_REFRESH_INTERVAL_MS)
                  return s_processorCount;

              // Slow path: (re)sample the processor count and remember when we did it.
              s_processorCount = Environment.ProcessorCount;
              s_lastProcessorCountRefreshTicks = nowTicks;
              return s_processorCount;
          }
      }
  }
  30. //Based on https://github.com/dotnet/corefx/blob/master/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs
  31. //Copyright (c) .NET Foundation and Contributors
  /// <summary>
  /// A thread-safe hash set that uses lock striping. Buckets are mapped onto a set of lock objects
  /// (<c>lockNo = bucketNo % lockCount</c>). Adds and removes lock only the stripe that owns the
  /// target bucket; <see cref="ContainsKey"/> and enumeration take no locks; growing, clearing and
  /// the aggregate queries (<see cref="Count"/>, <see cref="IsEmpty"/>, <see cref="Values"/>) take every lock.
  /// </summary>
  /// <typeparam name="T">Element type. <c>null</c> elements are rejected with <see cref="ArgumentNullException"/>.</typeparam>
  [DebuggerDisplay("Count = {Count}")]
  internal class ConcurrentHashSet<T> : IReadOnlyCollection<T>
  {
      /// <summary>
      /// The bucket array, lock stripes and per-stripe element counts that make up the set's current
      /// state. A new instance is published through the volatile <c>_tables</c> field on resize or clear.
      /// </summary>
      private sealed class Tables
      {
          internal readonly Node[] _buckets;
          internal readonly object[] _locks;
          internal volatile int[] _countPerLock;

          internal Tables(Node[] buckets, object[] locks, int[] countPerLock)
          {
              _buckets = buckets;
              _locks = locks;
              _countPerLock = countPerLock;
          }
      }

      /// <summary>Singly-linked bucket entry that caches the element's hash code.</summary>
      private sealed class Node
      {
          internal readonly T _value;
          internal volatile Node _next;
          internal readonly int _hashcode;

          internal Node(T key, int hashcode, Node next)
          {
              _value = key;
              _next = next;
              _hashcode = hashcode;
          }
      }

      private const int DefaultCapacity = 31;
      private const int MaxLockNumber = 1024;

      /// <summary>Maps a hash code to a bucket index in <c>[0, bucketCount)</c>.</summary>
      private static int GetBucket(int hashcode, int bucketCount)
      {
          // Clear the sign bit so negative hash codes still produce a non-negative index.
          return (hashcode & 0x7fffffff) % bucketCount;
      }

      /// <summary>Maps a hash code to its bucket index and to the index of the lock guarding that bucket.</summary>
      private static void GetBucketAndLockNo(int hashcode, out int bucketNo, out int lockNo, int bucketCount, int lockCount)
      {
          bucketNo = GetBucket(hashcode, bucketCount);
          lockNo = bucketNo % lockCount;
      }

      private static int DefaultConcurrencyLevel => ConcurrentHashSet.DefaultConcurrencyLevel;

      private volatile Tables _tables;
      private readonly IEqualityComparer<T> _comparer;
      // When true, GrowTable doubles the number of lock stripes (up to MaxLockNumber).
      private readonly bool _growLockArray;
      // Maximum element count for a single stripe before TryAdd asks GrowTable to resize.
      private int _budget;

      /// <summary>
      /// Gets the number of elements. Acquires every lock, so the result is exact at the moment it is
      /// computed, but all writers are blocked while it runs.
      /// </summary>
      public int Count
      {
          get
          {
              int locksAcquired = 0;
              try
              {
                  AcquireAllLocks(ref locksAcquired);
                  int count = 0;
                  foreach (int stripeCount in _tables._countPerLock)
                      count += stripeCount;
                  return count;
              }
              finally { ReleaseLocks(0, locksAcquired); }
          }
      }

      /// <summary>Gets whether the set has no elements. Acquires every lock.</summary>
      public bool IsEmpty
      {
          get
          {
              int locksAcquired = 0;
              try
              {
                  AcquireAllLocks(ref locksAcquired);
                  foreach (int stripeCount in _tables._countPerLock)
                  {
                      if (stripeCount != 0)
                          return false;
                  }
                  return true;
              }
              finally { ReleaseLocks(0, locksAcquired); }
          }
      }

      /// <summary>
      /// Gets a read-only copy of the current elements, taken while holding every lock.
      /// </summary>
      public ReadOnlyCollection<T> Values
      {
          get
          {
              int locksAcquired = 0;
              try
              {
                  AcquireAllLocks(ref locksAcquired);
                  List<T> values = new List<T>();
                  Node[] buckets = _tables._buckets;
                  for (int i = 0; i < buckets.Length; i++)
                  {
                      for (Node current = buckets[i]; current != null; current = current._next)
                          values.Add(current._value);
                  }
                  return new ReadOnlyCollection<T>(values);
              }
              finally { ReleaseLocks(0, locksAcquired); }
          }
      }

      /// <summary>
      /// Creates an empty set with the default concurrency level, default capacity, a lock array that
      /// may grow, and <see cref="EqualityComparer{T}.Default"/>.
      /// </summary>
      public ConcurrentHashSet()
          : this(DefaultConcurrencyLevel, DefaultCapacity, true, EqualityComparer<T>.Default) { }

      /// <summary>Creates an empty set with a fixed number of lock stripes and an initial bucket count.</summary>
      /// <exception cref="ArgumentOutOfRangeException"><paramref name="concurrencyLevel"/> is less than 1 or <paramref name="capacity"/> is negative.</exception>
      public ConcurrentHashSet(int concurrencyLevel, int capacity)
          : this(concurrencyLevel, capacity, false, EqualityComparer<T>.Default) { }

      /// <summary>Creates a set containing the elements of <paramref name="collection"/>.</summary>
      /// <exception cref="ArgumentNullException"><paramref name="collection"/> is null or contains a null element.</exception>
      /// <exception cref="ArgumentException"><paramref name="collection"/> contains duplicate elements.</exception>
      public ConcurrentHashSet(IEnumerable<T> collection)
          : this(collection, EqualityComparer<T>.Default) { }

      /// <summary>Creates an empty set that compares elements with <paramref name="comparer"/>.</summary>
      /// <exception cref="ArgumentNullException"><paramref name="comparer"/> is null.</exception>
      public ConcurrentHashSet(IEqualityComparer<T> comparer)
          : this(DefaultConcurrencyLevel, DefaultCapacity, true, comparer) { }

      /// <summary>Creates a set containing the elements of <paramref name="collection"/>, compared with <paramref name="comparer"/>.</summary>
      /// <exception cref="ArgumentNullException"><paramref name="collection"/> or <paramref name="comparer"/> is null, or the collection contains a null element.</exception>
      /// <exception cref="ArgumentException"><paramref name="collection"/> contains duplicate elements.</exception>
      public ConcurrentHashSet(IEnumerable<T> collection, IEqualityComparer<T> comparer)
          : this(comparer)
      {
          if (collection == null) throw new ArgumentNullException(paramName: nameof(collection));
          InitializeFromCollection(collection);
      }

      /// <summary>
      /// Creates a set with a fixed number of lock stripes containing the elements of <paramref name="collection"/>.
      /// </summary>
      /// <exception cref="ArgumentOutOfRangeException"><paramref name="concurrencyLevel"/> is less than 1.</exception>
      /// <exception cref="ArgumentNullException"><paramref name="collection"/> or <paramref name="comparer"/> is null, or the collection contains a null element.</exception>
      /// <exception cref="ArgumentException"><paramref name="collection"/> contains duplicate elements.</exception>
      public ConcurrentHashSet(int concurrencyLevel, IEnumerable<T> collection, IEqualityComparer<T> comparer)
          : this(concurrencyLevel, DefaultCapacity, false, comparer)
      {
          // A null comparer has already been rejected by the chained constructor.
          if (collection == null) throw new ArgumentNullException(paramName: nameof(collection));
          InitializeFromCollection(collection);
      }

      /// <summary>Creates an empty set with a fixed number of lock stripes, an initial bucket count and a comparer.</summary>
      /// <exception cref="ArgumentOutOfRangeException"><paramref name="concurrencyLevel"/> is less than 1 or <paramref name="capacity"/> is negative.</exception>
      /// <exception cref="ArgumentNullException"><paramref name="comparer"/> is null.</exception>
      public ConcurrentHashSet(int concurrencyLevel, int capacity, IEqualityComparer<T> comparer)
          : this(concurrencyLevel, capacity, false, comparer) { }

      /// <summary>
      /// Core constructor. Allocates <paramref name="concurrencyLevel"/> lock stripes and
      /// <c>max(capacity, concurrencyLevel)</c> buckets.
      /// </summary>
      /// <param name="growLockArray">Whether resizing may also double the number of lock stripes.</param>
      internal ConcurrentHashSet(int concurrencyLevel, int capacity, bool growLockArray, IEqualityComparer<T> comparer)
      {
          if (concurrencyLevel < 1) throw new ArgumentOutOfRangeException(paramName: nameof(concurrencyLevel));
          if (capacity < 0) throw new ArgumentOutOfRangeException(paramName: nameof(capacity));
          if (comparer == null) throw new ArgumentNullException(paramName: nameof(comparer));

          // Keep at least one bucket per lock so every stripe guards something.
          if (capacity < concurrencyLevel)
              capacity = concurrencyLevel;

          object[] locks = new object[concurrencyLevel];
          for (int i = 0; i < locks.Length; i++)
              locks[i] = new object();

          int[] countPerLock = new int[locks.Length];
          Node[] buckets = new Node[capacity];
          _tables = new Tables(buckets, locks, countPerLock);
          _comparer = comparer;
          _growLockArray = growLockArray;
          _budget = buckets.Length / locks.Length;
      }

      /// <summary>
      /// Adds every element of <paramref name="collection"/> without taking stripe locks. Only called
      /// from constructors, before the instance is visible to other threads.
      /// </summary>
      /// <exception cref="ArgumentNullException">An element is null.</exception>
      /// <exception cref="ArgumentException">An element appears more than once.</exception>
      private void InitializeFromCollection(IEnumerable<T> collection)
      {
          foreach (T value in collection)
          {
              if (value == null)
                  throw new ArgumentNullException(nameof(collection), "The collection contains a null element.");
              if (!TryAddInternal(value, _comparer.GetHashCode(value), false))
                  throw new ArgumentException("The collection contains duplicate elements.", nameof(collection));
          }
          if (_budget == 0)
              _budget = _tables._buckets.Length / _tables._locks.Length;
      }

      /// <summary>Returns whether <paramref name="value"/> is in the set. Takes no locks.</summary>
      /// <exception cref="ArgumentNullException"><paramref name="value"/> is null.</exception>
      public bool ContainsKey(T value)
      {
          if (value == null) throw new ArgumentNullException(paramName: nameof(value));
          return ContainsKeyInternal(value, _comparer.GetHashCode(value));
      }

      /// <summary>Lock-free lookup in the currently published tables.</summary>
      private bool ContainsKeyInternal(T value, int hashcode)
      {
          Tables tables = _tables;
          int bucketNo = GetBucket(hashcode, tables._buckets.Length);
          // Volatile read pairs with the Volatile.Write that publishes new bucket heads.
          for (Node n = Volatile.Read(ref tables._buckets[bucketNo]); n != null; n = n._next)
          {
              if (hashcode == n._hashcode && _comparer.Equals(n._value, value))
                  return true;
          }
          return false;
      }

      /// <summary>Adds <paramref name="value"/> if it is not already present.</summary>
      /// <returns><c>true</c> if the value was added; <c>false</c> if an equal value was already present.</returns>
      /// <exception cref="ArgumentNullException"><paramref name="value"/> is null.</exception>
      public bool TryAdd(T value)
      {
          if (value == null) throw new ArgumentNullException(paramName: nameof(value));
          return TryAddInternal(value, _comparer.GetHashCode(value), true);
      }

      /// <summary>
      /// Inserts <paramref name="value"/> at the head of its bucket unless an equal element exists.
      /// </summary>
      /// <param name="acquireLock">Whether to lock the owning stripe; <c>false</c> only during construction.</param>
      private bool TryAddInternal(T value, int hashcode, bool acquireLock)
      {
          while (true)
          {
              Tables tables = _tables;
              GetBucketAndLockNo(hashcode, out int bucketNo, out int lockNo, tables._buckets.Length, tables._locks.Length);
              bool resizeDesired = false;
              bool lockTaken = false;
              try
              {
                  if (acquireLock)
                      Monitor.Enter(tables._locks[lockNo], ref lockTaken);

                  // The tables may have been replaced after we computed bucketNo/lockNo; if so those
                  // indices are stale, so retry against the newly published tables.
                  if (tables != _tables)
                      continue;

                  for (Node node = tables._buckets[bucketNo]; node != null; node = node._next)
                  {
                      if (hashcode == node._hashcode && _comparer.Equals(node._value, value))
                          return false;
                  }

                  // Publish the new head with a volatile write for the lock-free readers.
                  Volatile.Write(ref tables._buckets[bucketNo], new Node(value, hashcode, tables._buckets[bucketNo]));
                  checked { tables._countPerLock[lockNo]++; }
                  if (tables._countPerLock[lockNo] > _budget)
                      resizeDesired = true;
              }
              finally
              {
                  if (lockTaken)
                      Monitor.Exit(tables._locks[lockNo]);
              }

              // Grow only after releasing the stripe lock: GrowTable acquires locks in index order
              // starting at 0, and holding stripe lockNo meanwhile would break that ordering.
              if (resizeDesired)
                  GrowTable(tables);
              return true;
          }
      }

      /// <summary>Removes <paramref name="value"/> if present.</summary>
      /// <returns><c>true</c> if an equal element was found and removed; otherwise <c>false</c>.</returns>
      /// <exception cref="ArgumentNullException"><paramref name="value"/> is null.</exception>
      public bool TryRemove(T value)
      {
          if (value == null) throw new ArgumentNullException(paramName: nameof(value));
          return TryRemoveInternal(value);
      }

      /// <summary>Unlinks the node holding an element equal to <paramref name="value"/> under its stripe lock.</summary>
      private bool TryRemoveInternal(T value)
      {
          int hashcode = _comparer.GetHashCode(value);
          while (true)
          {
              Tables tables = _tables;
              GetBucketAndLockNo(hashcode, out int bucketNo, out int lockNo, tables._buckets.Length, tables._locks.Length);
              lock (tables._locks[lockNo])
              {
                  // Retry if the tables were replaced before we obtained the lock.
                  if (tables != _tables)
                      continue;

                  Node prev = null;
                  for (Node curr = tables._buckets[bucketNo]; curr != null; curr = curr._next)
                  {
                      if (hashcode == curr._hashcode && _comparer.Equals(curr._value, value))
                      {
                          if (prev == null)
                              Volatile.Write(ref tables._buckets[bucketNo], curr._next);
                          else
                              prev._next = curr._next;
                          tables._countPerLock[lockNo]--;
                          return true;
                      }
                      prev = curr;
                  }
              }
              return false;
          }
      }

      /// <summary>
      /// Removes all elements by publishing fresh tables with <c>DefaultCapacity</c> buckets. The lock
      /// array is reused. Acquires every lock.
      /// </summary>
      public void Clear()
      {
          int locksAcquired = 0;
          try
          {
              AcquireAllLocks(ref locksAcquired);
              Tables newTables = new Tables(new Node[DefaultCapacity], _tables._locks, new int[_tables._countPerLock.Length]);
              _tables = newTables;
              _budget = Math.Max(1, newTables._buckets.Length / newTables._locks.Length);
          }
          finally
          {
              ReleaseLocks(0, locksAcquired);
          }
      }

      /// <summary>
      /// Enumerates the bucket array captured when enumeration starts, without taking locks. This is
      /// not a snapshot: concurrent adds and removes may or may not be observed.
      /// </summary>
      public IEnumerator<T> GetEnumerator()
      {
          Node[] buckets = _tables._buckets;
          for (int i = 0; i < buckets.Length; i++)
          {
              for (Node current = Volatile.Read(ref buckets[i]); current != null; current = current._next)
                  yield return current._value;
          }
      }

      IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

      /// <summary>
      /// Called when a stripe exceeds <c>_budget</c>. If the set is still sparse (fewer elements than a
      /// quarter of the buckets), it only doubles the budget. Otherwise it rehashes into roughly twice as
      /// many buckets and, when <c>_growLockArray</c> is set, twice as many lock stripes.
      /// </summary>
      /// <param name="tables">The tables the caller observed; growth is skipped if they were already replaced.</param>
      private void GrowTable(Tables tables)
      {
          const int MaxArrayLength = 0X7FEFFFFF;
          int locksAcquired = 0;
          try
          {
              // Lock 0 serializes growers. Check that another thread has not already replaced the tables.
              AcquireLocks(0, 1, ref locksAcquired);
              if (tables != _tables)
                  return;

              long approxCount = 0;
              for (int i = 0; i < tables._countPerLock.Length; i++)
                  approxCount += tables._countPerLock[i];

              if (approxCount < tables._buckets.Length / 4)
              {
                  _budget = 2 * _budget;
                  if (_budget < 0) // doubling wrapped past int.MaxValue
                      _budget = int.MaxValue;
                  return;
              }

              int newLength = 0;
              bool maximizeTableSize = false;
              try
              {
                  checked
                  {
                      // Roughly double, then step to a length not divisible by 3, 5 or 7.
                      newLength = tables._buckets.Length * 2 + 1;
                      while (newLength % 3 == 0 || newLength % 5 == 0 || newLength % 7 == 0)
                          newLength += 2;
                      if (newLength > MaxArrayLength)
                          maximizeTableSize = true;
                  }
              }
              catch (OverflowException)
              {
                  maximizeTableSize = true;
              }

              if (maximizeTableSize)
                  newLength = MaxArrayLength;

              AcquireLocks(1, tables._locks.Length, ref locksAcquired);

              object[] newLocks = tables._locks;
              if (_growLockArray && tables._locks.Length < MaxLockNumber)
              {
                  // Existing lock objects keep their indices, so locks held now are still released correctly.
                  newLocks = new object[tables._locks.Length * 2];
                  Array.Copy(tables._locks, 0, newLocks, 0, tables._locks.Length);
                  for (int i = tables._locks.Length; i < newLocks.Length; i++)
                      newLocks[i] = new object();
              }

              Node[] newBuckets = new Node[newLength];
              int[] newCountPerLock = new int[newLocks.Length];
              for (int i = 0; i < tables._buckets.Length; i++)
              {
                  Node current = tables._buckets[i];
                  while (current != null)
                  {
                      Node next = current._next;
                      GetBucketAndLockNo(current._hashcode, out int newBucketNo, out int newLockNo, newBuckets.Length, newLocks.Length);
                      newBuckets[newBucketNo] = new Node(current._value, current._hashcode, newBuckets[newBucketNo]);
                      checked { newCountPerLock[newLockNo]++; }
                      current = next;
                  }
              }

              // At maximum size the table cannot grow any further, so pin the budget to int.MaxValue
              // (only Clear() lowers it again). Otherwise later adds would keep re-entering GrowTable
              // and rehashing into an array of the same size.
              _budget = maximizeTableSize ? int.MaxValue : Math.Max(1, newBuckets.Length / newLocks.Length);
              _tables = new Tables(newBuckets, newLocks, newCountPerLock);
          }
          finally { ReleaseLocks(0, locksAcquired); }
      }

      /// <summary>
      /// Acquires every lock stripe in index order. Lock 0 is taken on its own first: GrowTable and
      /// Clear both need it before replacing <c>_tables</c>, so once it is held the lock count read
      /// for the remaining stripes is stable.
      /// </summary>
      private void AcquireAllLocks(ref int locksAcquired)
      {
          AcquireLocks(0, 1, ref locksAcquired);
          AcquireLocks(1, _tables._locks.Length, ref locksAcquired);
      }

      /// <summary>
      /// Acquires locks <c>[fromInclusive, toExclusive)</c> of the current tables, counting each
      /// successful acquisition in <paramref name="locksAcquired"/> so a <c>finally</c> can release
      /// exactly those locks.
      /// </summary>
      private void AcquireLocks(int fromInclusive, int toExclusive, ref int locksAcquired)
      {
          object[] locks = _tables._locks;
          for (int i = fromInclusive; i < toExclusive; i++)
          {
              bool lockTaken = false;
              try
              {
                  Monitor.Enter(locks[i], ref lockTaken);
              }
              finally
              {
                  if (lockTaken)
                      locksAcquired++;
              }
          }
      }

      /// <summary>
      /// Releases locks <c>[fromInclusive, toExclusive)</c> of the current tables. This stays correct
      /// after GrowTable or Clear publish new tables, because both keep the existing lock objects at the
      /// same indices.
      /// </summary>
      private void ReleaseLocks(int fromInclusive, int toExclusive)
      {
          for (int i = fromInclusive; i < toExclusive; i++)
              Monitor.Exit(_tables._locks[i]);
      }
  }
  406. }