| @@ -38,9 +38,9 @@ namespace Tensorflow | |||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||
| { | |||
| _graph = g is null ? ops.get_default_graph() : g; | |||
| _graph = g ?? ops.get_default_graph(); | |||
| _graph.as_default(); | |||
| _target = UTF8Encoding.UTF8.GetBytes(target); | |||
| _target = Encoding.UTF8.GetBytes(target); | |||
| SessionOptions newOpts = opts ?? new SessionOptions(); | |||
| @@ -42,7 +42,7 @@ namespace Tensorflow | |||
| public Session as_default() | |||
| { | |||
| tf.defaultSession = this; | |||
| tf._defaultSessionFactory.Value = this; | |||
| return this; | |||
| } | |||
| @@ -19,6 +19,7 @@ using System.Collections.Generic; | |||
| using System.Runtime.InteropServices; | |||
| using Google.Protobuf; | |||
| using System.Linq; | |||
| using System.Threading; | |||
| using NumSharp; | |||
| using static Tensorflow.Binding; | |||
| @@ -26,6 +27,10 @@ namespace Tensorflow | |||
| { | |||
| public partial class ops | |||
| { | |||
| private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack()); | |||
| public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value; | |||
| public static int tensor_id(Tensor tensor) | |||
| { | |||
| return tensor.Id; | |||
| @@ -72,8 +77,6 @@ namespace Tensorflow | |||
| return get_default_graph().get_collection_ref(key); | |||
| } | |||
| public static DefaultGraphStack default_graph_stack = new DefaultGraphStack(); | |||
| /// <summary> | |||
| /// Returns the default graph for the current thread. | |||
| /// | |||
| @@ -387,8 +390,6 @@ namespace Tensorflow | |||
| /// <returns>The default `Session` being used in the current thread.</returns> | |||
| public static Session get_default_session() | |||
| { | |||
| if (tf.defaultSession == null) | |||
| tf.defaultSession = tf.Session(); | |||
| return tf.defaultSession; | |||
| } | |||
| @@ -14,12 +14,15 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.Threading; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class tensorflow : IObjectLife | |||
| { | |||
| protected internal readonly ThreadLocal<Session> _defaultSessionFactory; | |||
| public TF_DataType @byte = TF_DataType.TF_UINT8; | |||
| public TF_DataType @sbyte = TF_DataType.TF_INT8; | |||
| public TF_DataType int16 = TF_DataType.TF_INT16; | |||
| @@ -34,7 +37,13 @@ namespace Tensorflow | |||
| public Context context = new Context(new ContextOptions(), new Status()); | |||
| public Session defaultSession; | |||
| public tensorflow() | |||
| { | |||
| _defaultSessionFactory = new ThreadLocal<Session>(Session); | |||
| } | |||
| public Session defaultSession => _defaultSessionFactory.Value; | |||
| public RefVariable Variable<T>(T data, | |||
| bool trainable = true, | |||
| @@ -0,0 +1,71 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.InteropServices; | |||
| using FluentAssertions; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class MultithreadingTests | |||
| { | |||
| [TestMethod] | |||
| public void SessionCreation() | |||
| { | |||
| tf.Session(); //create one to increase next id to 1. | |||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||
| //the core method | |||
| void Core(int tid) | |||
| { | |||
| tf.peak_default_graph().Should().BeNull(); | |||
| //tf.Session created an other graph | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var default_graph = tf.peak_default_graph(); | |||
| var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||
| sess_graph.Should().NotBeNull(); | |||
| default_graph.Should().NotBeNull() | |||
| .And.BeEquivalentTo(sess_graph); | |||
| } | |||
| } | |||
| } | |||
| [TestMethod] | |||
| public void GraphCreation() | |||
| { | |||
| tf.Graph(); //create one to increase next id to 1. | |||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||
| //the core method | |||
| void Core(int tid) | |||
| { | |||
| tf.peak_default_graph().Should().BeNull(); | |||
| var beforehand = tf.get_default_graph(); //this should create default automatically. | |||
| beforehand.graph_key.Should().NotContain("0", "Already created a graph in an other thread."); | |||
| tf.peak_default_graph().Should().NotBeNull(); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var default_graph = tf.peak_default_graph(); | |||
| var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||
| sess_graph.Should().NotBeNull(); | |||
| default_graph.Should().NotBeNull() | |||
| .And.BeEquivalentTo(sess_graph) | |||
| .And.BeEquivalentTo(beforehand); | |||
| Console.WriteLine($"{tid}-{default_graph.graph_key}"); | |||
| //var result = sess.run(new object[] {g, a}); | |||
| //var actualDeriv = result[0].GetData<float>()[0]; | |||
| //var actual = result[1].GetData<float>()[0]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -93,7 +93,7 @@ namespace TensorFlowNET.UnitTest | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = c.eval(sess); | |||
| Assert.AreEqual(6, result.Data<double>()[0]); | |||
| Assert.AreEqual(6, result.GetAtIndex<double>(0)); | |||
| } | |||
| } | |||
| } | |||