From f27b0c892f24443e9653c6b1cf3d39e33c1973b3 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Fri, 30 Aug 2019 21:05:14 +0300 Subject: [PATCH] Graph, Session: defaulting was changed to thread-wide using ThreadLocal --- .../Sessions/BaseSession.cs | 4 +- src/TensorFlowNET.Core/Sessions/Session.cs | 2 +- src/TensorFlowNET.Core/ops.cs | 9 +-- src/TensorFlowNET.Core/tensorflow.cs | 11 ++- .../MultithreadingTests.cs | 71 +++++++++++++++++++ test/TensorFlowNET.UnitTest/SessionTest.cs | 2 +- 6 files changed, 90 insertions(+), 9 deletions(-) create mode 100644 test/TensorFlowNET.UnitTest/MultithreadingTests.cs diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index c3368120..bd89107f 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -38,9 +38,9 @@ namespace Tensorflow public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) { - _graph = g is null ? ops.get_default_graph() : g; + _graph = g ?? ops.get_default_graph(); _graph.as_default(); - _target = UTF8Encoding.UTF8.GetBytes(target); + _target = Encoding.UTF8.GetBytes(target); SessionOptions newOpts = opts ?? new SessionOptions(); diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index ec2e443f..649e7806 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -42,7 +42,7 @@ namespace Tensorflow public Session as_default() { - tf.defaultSession = this; + tf._defaultSessionFactory.Value = this; return this; } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 18184edd..b5c4c0d1 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Runtime.InteropServices; using Google.Protobuf; using System.Linq; +using System.Threading; using NumSharp; using static Tensorflow.Binding; @@ -26,6 +27,10 @@ namespace Tensorflow { public partial class ops { + private static readonly ThreadLocal _defaultGraphFactory = new ThreadLocal(() => new DefaultGraphStack()); + + public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value; + public static int tensor_id(Tensor tensor) { return tensor.Id; @@ -72,8 +77,6 @@ namespace Tensorflow return get_default_graph().get_collection_ref(key); } - public static DefaultGraphStack default_graph_stack = new DefaultGraphStack(); - /// /// Returns the default graph for the current thread. /// @@ -387,8 +390,6 @@ namespace Tensorflow /// The default `Session` being used in the current thread. public static Session get_default_session() { - if (tf.defaultSession == null) - tf.defaultSession = tf.Session(); return tf.defaultSession; } diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index ca903844..bdb2f537 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -14,12 +14,15 @@ limitations under the License. ******************************************************************************/ +using System.Threading; using Tensorflow.Eager; namespace Tensorflow { public partial class tensorflow : IObjectLife { + protected internal readonly ThreadLocal _defaultSessionFactory; + public TF_DataType @byte = TF_DataType.TF_UINT8; public TF_DataType @sbyte = TF_DataType.TF_INT8; public TF_DataType int16 = TF_DataType.TF_INT16; @@ -34,7 +37,13 @@ namespace Tensorflow public Context context = new Context(new ContextOptions(), new Status()); - public Session defaultSession; + + public tensorflow() + { + _defaultSessionFactory = new ThreadLocal(Session); + } + + public Session defaultSession => _defaultSessionFactory.Value; public RefVariable Variable(T data, bool trainable = true, diff --git a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs new file mode 100644 index 00000000..b889e267 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs @@ -0,0 +1,71 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using FluentAssertions; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class MultithreadingTests + { + [TestMethod] + public void SessionCreation() + { + tf.Session(); //create one to increase next id to 1. + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + tf.peak_default_graph().Should().BeNull(); + + //tf.Session created an other graph + using (var sess = tf.Session()) + { + var default_graph = tf.peak_default_graph(); + var sess_graph = sess.GetPrivate("_graph"); + sess_graph.Should().NotBeNull(); + default_graph.Should().NotBeNull() + .And.BeEquivalentTo(sess_graph); + } + } + } + + [TestMethod] + public void GraphCreation() + { + tf.Graph(); //create one to increase next id to 1. + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + tf.peak_default_graph().Should().BeNull(); + var beforehand = tf.get_default_graph(); //this should create default automatically. + beforehand.graph_key.Should().NotContain("0", "Already created a graph in an other thread."); + tf.peak_default_graph().Should().NotBeNull(); + + using (var sess = tf.Session()) + { + var default_graph = tf.peak_default_graph(); + var sess_graph = sess.GetPrivate("_graph"); + sess_graph.Should().NotBeNull(); + default_graph.Should().NotBeNull() + .And.BeEquivalentTo(sess_graph) + .And.BeEquivalentTo(beforehand); + + Console.WriteLine($"{tid}-{default_graph.graph_key}"); + + //var result = sess.run(new object[] {g, a}); + //var actualDeriv = result[0].GetData()[0]; + //var actual = result[1].GetData()[0]; + } + } + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs index 45005a59..91c75a13 100644 --- a/test/TensorFlowNET.UnitTest/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -93,7 +93,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var result = c.eval(sess); - Assert.AreEqual(6, result.Data()[0]); + Assert.AreEqual(6, result.GetAtIndex(0)); } } }