| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.Threading; | using System.Threading; | ||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| { | { | ||||
| @@ -13,6 +14,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| public int ThreadCount { get; } | public int ThreadCount { get; } | ||||
| public Thread[] Threads { get; } | public Thread[] Threads { get; } | ||||
| public Exception[] Exceptions { get; } | |||||
| private readonly SemaphoreSlim barrier_threadstarted; | private readonly SemaphoreSlim barrier_threadstarted; | ||||
| private readonly ManualResetEventSlim barrier_corestart; | private readonly ManualResetEventSlim barrier_corestart; | ||||
| private readonly SemaphoreSlim done_barrier2; | private readonly SemaphoreSlim done_barrier2; | ||||
| @@ -57,6 +59,7 @@ namespace TensorFlowNET.UnitTest | |||||
| throw new ArgumentOutOfRangeException(nameof(threadCount)); | throw new ArgumentOutOfRangeException(nameof(threadCount)); | ||||
| ThreadCount = threadCount; | ThreadCount = threadCount; | ||||
| Threads = new Thread[ThreadCount]; | Threads = new Thread[ThreadCount]; | ||||
| Exceptions = new Exception[ThreadCount]; | |||||
| done_barrier2 = new SemaphoreSlim(0, threadCount); | done_barrier2 = new SemaphoreSlim(0, threadCount); | ||||
| barrier_corestart = new ManualResetEventSlim(); | barrier_corestart = new ManualResetEventSlim(); | ||||
| barrier_threadstarted = new SemaphoreSlim(0, threadCount); | barrier_threadstarted = new SemaphoreSlim(0, threadCount); | ||||
| @@ -72,28 +75,53 @@ namespace TensorFlowNET.UnitTest | |||||
| if (ThreadCount == 1) | if (ThreadCount == 1) | ||||
| { | { | ||||
| Exception ex = null; | |||||
| new Thread(() => | new Thread(() => | ||||
| { | { | ||||
| workloads[0](0); | |||||
| done_barrier2.Release(1); | |||||
| try | |||||
| { | |||||
| workloads[0](0); | |||||
| } catch (Exception e) | |||||
| { | |||||
| if (Debugger.IsAttached) | |||||
| throw; | |||||
| ex = e; | |||||
| } finally | |||||
| { | |||||
| done_barrier2.Release(1); | |||||
| } | |||||
| }).Start(); | }).Start(); | ||||
| done_barrier2.Wait(); | done_barrier2.Wait(); | ||||
| if (ex != null) | |||||
| throw new Exception($"Thread 0 has failed: ", ex); | |||||
| PostRun?.Invoke(this); | PostRun?.Invoke(this); | ||||
| return; | return; | ||||
| } | } | ||||
| //thread core | //thread core | ||||
| void ThreadCore(MultiThreadedTestDelegate core, int threadid) | |||||
| Exception ThreadCore(MultiThreadedTestDelegate core, int threadid) | |||||
| { | { | ||||
| barrier_threadstarted.Release(1); | barrier_threadstarted.Release(1); | ||||
| barrier_corestart.Wait(); | barrier_corestart.Wait(); | ||||
| //workload | //workload | ||||
| core(threadid); | |||||
| try | |||||
| { | |||||
| core(threadid); | |||||
| } catch (Exception e) | |||||
| { | |||||
| if (Debugger.IsAttached) | |||||
| throw; | |||||
| return e; | |||||
| } finally | |||||
| { | |||||
| done_barrier2.Release(1); | |||||
| } | |||||
| done_barrier2.Release(1); | |||||
| return null; | |||||
| } | } | ||||
| //initialize all threads | //initialize all threads | ||||
| @@ -103,7 +131,7 @@ namespace TensorFlowNET.UnitTest | |||||
| for (int i = 0; i < ThreadCount; i++) | for (int i = 0; i < ThreadCount; i++) | ||||
| { | { | ||||
| var i_local = i; | var i_local = i; | ||||
| Threads[i] = new Thread(() => ThreadCore(workload, i_local)); | |||||
| Threads[i] = new Thread(() => Exceptions[i_local] = ThreadCore(workload, i_local)); | |||||
| } | } | ||||
| } else | } else | ||||
| { | { | ||||
| @@ -111,7 +139,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| var i_local = i; | var i_local = i; | ||||
| var workload = workloads[i_local % workloads.Length]; | var workload = workloads[i_local % workloads.Length]; | ||||
| Threads[i] = new Thread(() => ThreadCore(workload, i_local)); | |||||
| Threads[i] = new Thread(() => Exceptions[i_local] = ThreadCore(workload, i_local)); | |||||
| } | } | ||||
| } | } | ||||
| @@ -126,6 +154,11 @@ namespace TensorFlowNET.UnitTest | |||||
| //wait for threads to finish | //wait for threads to finish | ||||
| for (int i = 0; i < ThreadCount; i++) done_barrier2.Wait(); | for (int i = 0; i < ThreadCount; i++) done_barrier2.Wait(); | ||||
| //handle fails | |||||
| for (int i = 0; i < ThreadCount; i++) | |||||
| if (Exceptions[i] != null) | |||||
| throw new Exception($"Thread {i} has failed: ", Exceptions[i]); | |||||
| //checks after ended | //checks after ended | ||||
| PostRun?.Invoke(this); | PostRun?.Invoke(this); | ||||
| } | } | ||||