Browse Source

Graph, Session: defaulting was changed to thread-wide using ThreadLocal<T>

tags/v0.12
Eli Belash 6 years ago
parent
commit
f27b0c892f
6 changed files with 90 additions and 9 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Sessions/Session.cs
  3. +5
    -4
      src/TensorFlowNET.Core/ops.cs
  4. +10
    -1
      src/TensorFlowNET.Core/tensorflow.cs
  5. +71
    -0
      test/TensorFlowNET.UnitTest/MultithreadingTests.cs
  6. +1
    -1
      test/TensorFlowNET.UnitTest/SessionTest.cs

+ 2
- 2
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -38,9 +38,9 @@ namespace Tensorflow


public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) public BaseSession(string target = "", Graph g = null, SessionOptions opts = null)
{ {
_graph = g is null ? ops.get_default_graph() : g;
_graph = g ?? ops.get_default_graph();
_graph.as_default(); _graph.as_default();
_target = UTF8Encoding.UTF8.GetBytes(target);
_target = Encoding.UTF8.GetBytes(target);


SessionOptions newOpts = opts ?? new SessionOptions(); SessionOptions newOpts = opts ?? new SessionOptions();




+ 1
- 1
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow


public Session as_default() public Session as_default()
{ {
tf.defaultSession = this;
tf._defaultSessionFactory.Value = this;
return this; return this;
} }




+ 5
- 4
src/TensorFlowNET.Core/ops.cs View File

@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using Google.Protobuf; using Google.Protobuf;
using System.Linq; using System.Linq;
using System.Threading;
using NumSharp; using NumSharp;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -26,6 +27,10 @@ namespace Tensorflow
{ {
public partial class ops public partial class ops
{ {
private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack());

public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value;

public static int tensor_id(Tensor tensor) public static int tensor_id(Tensor tensor)
{ {
return tensor.Id; return tensor.Id;
@@ -72,8 +77,6 @@ namespace Tensorflow
return get_default_graph().get_collection_ref(key); return get_default_graph().get_collection_ref(key);
} }


public static DefaultGraphStack default_graph_stack = new DefaultGraphStack();

/// <summary> /// <summary>
/// Returns the default graph for the current thread. /// Returns the default graph for the current thread.
/// ///
@@ -387,8 +390,6 @@ namespace Tensorflow
/// <returns>The default `Session` being used in the current thread.</returns> /// <returns>The default `Session` being used in the current thread.</returns>
public static Session get_default_session() public static Session get_default_session()
{ {
if (tf.defaultSession == null)
tf.defaultSession = tf.Session();
return tf.defaultSession; return tf.defaultSession;
} }




+ 10
- 1
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -14,12 +14,15 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Threading;
using Tensorflow.Eager; using Tensorflow.Eager;


namespace Tensorflow namespace Tensorflow
{ {
public partial class tensorflow : IObjectLife public partial class tensorflow : IObjectLife
{ {
protected internal readonly ThreadLocal<Session> _defaultSessionFactory;

public TF_DataType @byte = TF_DataType.TF_UINT8; public TF_DataType @byte = TF_DataType.TF_UINT8;
public TF_DataType @sbyte = TF_DataType.TF_INT8; public TF_DataType @sbyte = TF_DataType.TF_INT8;
public TF_DataType int16 = TF_DataType.TF_INT16; public TF_DataType int16 = TF_DataType.TF_INT16;
@@ -34,7 +37,13 @@ namespace Tensorflow


public Context context = new Context(new ContextOptions(), new Status()); public Context context = new Context(new ContextOptions(), new Status());


public Session defaultSession;

public tensorflow()
{
_defaultSessionFactory = new ThreadLocal<Session>(Session);
}

public Session defaultSession => _defaultSessionFactory.Value;


public RefVariable Variable<T>(T data, public RefVariable Variable<T>(T data,
bool trainable = true, bool trainable = true,


+ 71
- 0
test/TensorFlowNET.UnitTest/MultithreadingTests.cs View File

@@ -0,0 +1,71 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
{
[TestClass]
public class MultithreadingTests
{
[TestMethod]
public void SessionCreation()
{
tf.Session(); //create one to increase next id to 1.

MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method
void Core(int tid)
{
tf.peak_default_graph().Should().BeNull();

//tf.Session created an other graph
using (var sess = tf.Session())
{
var default_graph = tf.peak_default_graph();
var sess_graph = sess.GetPrivate<Graph>("_graph");
sess_graph.Should().NotBeNull();
default_graph.Should().NotBeNull()
.And.BeEquivalentTo(sess_graph);
}
}
}

[TestMethod]
public void GraphCreation()
{
tf.Graph(); //create one to increase next id to 1.

MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method
void Core(int tid)
{
tf.peak_default_graph().Should().BeNull();
var beforehand = tf.get_default_graph(); //this should create default automatically.
beforehand.graph_key.Should().NotContain("0", "Already created a graph in an other thread.");
tf.peak_default_graph().Should().NotBeNull();

using (var sess = tf.Session())
{
var default_graph = tf.peak_default_graph();
var sess_graph = sess.GetPrivate<Graph>("_graph");
sess_graph.Should().NotBeNull();
default_graph.Should().NotBeNull()
.And.BeEquivalentTo(sess_graph)
.And.BeEquivalentTo(beforehand);

Console.WriteLine($"{tid}-{default_graph.graph_key}");

//var result = sess.run(new object[] {g, a});
//var actualDeriv = result[0].GetData<float>()[0];
//var actual = result[1].GetData<float>()[0];
}
}
}
}
}

+ 1
- 1
test/TensorFlowNET.UnitTest/SessionTest.cs View File

@@ -93,7 +93,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session()) using (var sess = tf.Session())
{ {
var result = c.eval(sess); var result = c.eval(sess);
Assert.AreEqual(6, result.Data<double>()[0]);
Assert.AreEqual(6, result.GetAtIndex<double>(0));
} }
} }
} }


Loading…
Cancel
Save