| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Diagnostics; | |||
| using System.Threading; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| @@ -13,6 +14,7 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| public int ThreadCount { get; } | |||
| public Thread[] Threads { get; } | |||
| public Exception[] Exceptions { get; } | |||
| private readonly SemaphoreSlim barrier_threadstarted; | |||
| private readonly ManualResetEventSlim barrier_corestart; | |||
| private readonly SemaphoreSlim done_barrier2; | |||
| @@ -57,6 +59,7 @@ namespace TensorFlowNET.UnitTest | |||
| throw new ArgumentOutOfRangeException(nameof(threadCount)); | |||
| ThreadCount = threadCount; | |||
| Threads = new Thread[ThreadCount]; | |||
| Exceptions = new Exception[ThreadCount]; | |||
| done_barrier2 = new SemaphoreSlim(0, threadCount); | |||
| barrier_corestart = new ManualResetEventSlim(); | |||
| barrier_threadstarted = new SemaphoreSlim(0, threadCount); | |||
| @@ -72,28 +75,53 @@ namespace TensorFlowNET.UnitTest | |||
| if (ThreadCount == 1) | |||
| { | |||
| Exception ex = null; | |||
| new Thread(() => | |||
| { | |||
| workloads[0](0); | |||
| done_barrier2.Release(1); | |||
| try | |||
| { | |||
| workloads[0](0); | |||
| } catch (Exception e) | |||
| { | |||
| if (Debugger.IsAttached) | |||
| throw; | |||
| ex = e; | |||
| } finally | |||
| { | |||
| done_barrier2.Release(1); | |||
| } | |||
| }).Start(); | |||
| done_barrier2.Wait(); | |||
| if (ex != null) | |||
| throw new Exception($"Thread 0 has failed: ", ex); | |||
| PostRun?.Invoke(this); | |||
| return; | |||
| } | |||
| //thread core | |||
| void ThreadCore(MultiThreadedTestDelegate core, int threadid) | |||
| Exception ThreadCore(MultiThreadedTestDelegate core, int threadid) | |||
| { | |||
| barrier_threadstarted.Release(1); | |||
| barrier_corestart.Wait(); | |||
| //workload | |||
| core(threadid); | |||
| try | |||
| { | |||
| core(threadid); | |||
| } catch (Exception e) | |||
| { | |||
| if (Debugger.IsAttached) | |||
| throw; | |||
| return e; | |||
| } finally | |||
| { | |||
| done_barrier2.Release(1); | |||
| } | |||
| done_barrier2.Release(1); | |||
| return null; | |||
| } | |||
| //initialize all threads | |||
| @@ -103,7 +131,7 @@ namespace TensorFlowNET.UnitTest | |||
| for (int i = 0; i < ThreadCount; i++) | |||
| { | |||
| var i_local = i; | |||
| Threads[i] = new Thread(() => ThreadCore(workload, i_local)); | |||
| Threads[i] = new Thread(() => Exceptions[i_local] = ThreadCore(workload, i_local)); | |||
| } | |||
| } else | |||
| { | |||
| @@ -111,7 +139,7 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| var i_local = i; | |||
| var workload = workloads[i_local % workloads.Length]; | |||
| Threads[i] = new Thread(() => ThreadCore(workload, i_local)); | |||
| Threads[i] = new Thread(() => Exceptions[i_local] = ThreadCore(workload, i_local)); | |||
| } | |||
| } | |||
| @@ -126,6 +154,11 @@ namespace TensorFlowNET.UnitTest | |||
| //wait for threads to finish | |||
| for (int i = 0; i < ThreadCount; i++) done_barrier2.Wait(); | |||
| //handle fails | |||
| for (int i = 0; i < ThreadCount; i++) | |||
| if (Exceptions[i] != null) | |||
| throw new Exception($"Thread {i} has failed: ", Exceptions[i]); | |||
| //checks after ended | |||
| PostRun?.Invoke(this); | |||
| } | |||