Browse Source

fix: error when using graph in multi-threads.

pull/1068/head
Yaohui Liu 2 years ago
parent
commit
71ade1bc8c
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
2 changed files with 44 additions and 2 deletions
  1. +3
    -2
      src/TensorFlowNET.Core/Device/DeviceSpec.cs
  2. +41
    -0
      test/TensorFlowNET.UnitTest/Basics/ThreadSafeTest.cs

+ 3
- 2
src/TensorFlowNET.Core/Device/DeviceSpec.cs View File

@@ -1,4 +1,5 @@
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
@@ -7,8 +8,8 @@ namespace Tensorflow.Device
{ {
public class DeviceSpec public class DeviceSpec
{ {
private static Dictionary<string, Components> _STRING_TO_COMPONENTS_CACHE = new();
private static Dictionary<Components, string> _COMPONENTS_TO_STRING_CACHE = new();
private static ConcurrentDictionary<string, Components> _STRING_TO_COMPONENTS_CACHE = new();
private static ConcurrentDictionary<Components, string> _COMPONENTS_TO_STRING_CACHE = new();
private string _job; private string _job;
private int _replica; private int _replica;
private int _task; private int _task;


+ 41
- 0
test/TensorFlowNET.UnitTest/Basics/ThreadSafeTest.cs View File

@@ -0,0 +1,41 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Basics
{
[TestClass]
public class ThreadSafeTest
{
[TestMethod]
public void GraphWithMultiThreads()
{
List<Thread> threads = new List<Thread>();

const int THREADS_COUNT = 5;

for (int t = 0; t < THREADS_COUNT; t++)
{
Thread thread = new Thread(() =>
{
Graph g = new Graph();
Session session = new Session(g);
session.as_default();
var input = tf.placeholder(tf.int32, shape: new Shape(6));
var op = tf.reshape(input, new int[] { 2, 3 });
});
thread.Start();
threads.Add(thread);
}

threads.ForEach(t => t.Join());
}
}
}

Loading…
Cancel
Save