| @@ -38,9 +38,9 @@ namespace Tensorflow | |||||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | ||||
| { | { | ||||
| _graph = g is null ? ops.get_default_graph() : g; | |||||
| _graph = g ?? ops.get_default_graph(); | |||||
| _graph.as_default(); | _graph.as_default(); | ||||
| _target = UTF8Encoding.UTF8.GetBytes(target); | |||||
| _target = Encoding.UTF8.GetBytes(target); | |||||
| SessionOptions newOpts = opts ?? new SessionOptions(); | SessionOptions newOpts = opts ?? new SessionOptions(); | ||||
| @@ -42,7 +42,7 @@ namespace Tensorflow | |||||
| public Session as_default() | public Session as_default() | ||||
| { | { | ||||
| tf.defaultSession = this; | |||||
| tf._defaultSessionFactory.Value = this; | |||||
| return this; | return this; | ||||
| } | } | ||||
| @@ -19,6 +19,7 @@ using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Threading; | |||||
| using NumSharp; | using NumSharp; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -26,6 +27,10 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class ops | public partial class ops | ||||
| { | { | ||||
| private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack()); | |||||
| public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value; | |||||
| public static int tensor_id(Tensor tensor) | public static int tensor_id(Tensor tensor) | ||||
| { | { | ||||
| return tensor.Id; | return tensor.Id; | ||||
| @@ -72,8 +77,6 @@ namespace Tensorflow | |||||
| return get_default_graph().get_collection_ref(key); | return get_default_graph().get_collection_ref(key); | ||||
| } | } | ||||
| public static DefaultGraphStack default_graph_stack = new DefaultGraphStack(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the default graph for the current thread. | /// Returns the default graph for the current thread. | ||||
| /// | /// | ||||
| @@ -387,8 +390,6 @@ namespace Tensorflow | |||||
| /// <returns>The default `Session` being used in the current thread.</returns> | /// <returns>The default `Session` being used in the current thread.</returns> | ||||
| public static Session get_default_session() | public static Session get_default_session() | ||||
| { | { | ||||
| if (tf.defaultSession == null) | |||||
| tf.defaultSession = tf.Session(); | |||||
| return tf.defaultSession; | return tf.defaultSession; | ||||
| } | } | ||||
| @@ -14,12 +14,15 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Threading; | |||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class tensorflow : IObjectLife | public partial class tensorflow : IObjectLife | ||||
| { | { | ||||
| protected internal readonly ThreadLocal<Session> _defaultSessionFactory; | |||||
| public TF_DataType @byte = TF_DataType.TF_UINT8; | public TF_DataType @byte = TF_DataType.TF_UINT8; | ||||
| public TF_DataType @sbyte = TF_DataType.TF_INT8; | public TF_DataType @sbyte = TF_DataType.TF_INT8; | ||||
| public TF_DataType int16 = TF_DataType.TF_INT16; | public TF_DataType int16 = TF_DataType.TF_INT16; | ||||
| @@ -34,7 +37,13 @@ namespace Tensorflow | |||||
| public Context context = new Context(new ContextOptions(), new Status()); | public Context context = new Context(new ContextOptions(), new Status()); | ||||
| public Session defaultSession; | |||||
| public tensorflow() | |||||
| { | |||||
| _defaultSessionFactory = new ThreadLocal<Session>(Session); | |||||
| } | |||||
| public Session defaultSession => _defaultSessionFactory.Value; | |||||
| public RefVariable Variable<T>(T data, | public RefVariable Variable<T>(T data, | ||||
| bool trainable = true, | bool trainable = true, | ||||
| @@ -0,0 +1,71 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | |||||
| using FluentAssertions; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | |||||
| [TestClass] | |||||
| public class MultithreadingTests | |||||
| { | |||||
| [TestMethod] | |||||
| public void SessionCreation() | |||||
| { | |||||
| tf.Session(); //create one to increase next id to 1. | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
| //the core method | |||||
| void Core(int tid) | |||||
| { | |||||
| tf.peak_default_graph().Should().BeNull(); | |||||
| //tf.Session created an other graph | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var default_graph = tf.peak_default_graph(); | |||||
| var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||||
| sess_graph.Should().NotBeNull(); | |||||
| default_graph.Should().NotBeNull() | |||||
| .And.BeEquivalentTo(sess_graph); | |||||
| } | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void GraphCreation() | |||||
| { | |||||
| tf.Graph(); //create one to increase next id to 1. | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
| //the core method | |||||
| void Core(int tid) | |||||
| { | |||||
| tf.peak_default_graph().Should().BeNull(); | |||||
| var beforehand = tf.get_default_graph(); //this should create default automatically. | |||||
| beforehand.graph_key.Should().NotContain("0", "Already created a graph in an other thread."); | |||||
| tf.peak_default_graph().Should().NotBeNull(); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var default_graph = tf.peak_default_graph(); | |||||
| var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||||
| sess_graph.Should().NotBeNull(); | |||||
| default_graph.Should().NotBeNull() | |||||
| .And.BeEquivalentTo(sess_graph) | |||||
| .And.BeEquivalentTo(beforehand); | |||||
| Console.WriteLine($"{tid}-{default_graph.graph_key}"); | |||||
| //var result = sess.run(new object[] {g, a}); | |||||
| //var actualDeriv = result[0].GetData<float>()[0]; | |||||
| //var actual = result[1].GetData<float>()[0]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -93,7 +93,7 @@ namespace TensorFlowNET.UnitTest | |||||
| using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
| { | { | ||||
| var result = c.eval(sess); | var result = c.eval(sess); | ||||
| Assert.AreEqual(6, result.Data<double>()[0]); | |||||
| Assert.AreEqual(6, result.GetAtIndex<double>(0)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||