- Separated multithreading related methods to classname.threading.cs partial file - ops: Added enforce_singlethreading(), enforce_multithreading()tags/v0.13
| @@ -37,8 +37,7 @@ namespace Tensorflow | |||||
| public Session as_default() | public Session as_default() | ||||
| { | { | ||||
| tf._defaultSessionFactory.Value = this; | |||||
| return this; | |||||
| return ops.set_default_session(this); | |||||
| } | } | ||||
| [MethodImpl(MethodImplOptions.NoOptimization)] | [MethodImpl(MethodImplOptions.NoOptimization)] | ||||
| @@ -28,10 +28,6 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class ops | public partial class ops | ||||
| { | { | ||||
| private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack()); | |||||
| public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value; | |||||
| public static int tensor_id(Tensor tensor) | public static int tensor_id(Tensor tensor) | ||||
| { | { | ||||
| return tensor.Id; | return tensor.Id; | ||||
| @@ -78,53 +74,6 @@ namespace Tensorflow | |||||
| return get_default_graph().get_collection_ref<T>(key); | return get_default_graph().get_collection_ref<T>(key); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Returns the default graph for the current thread. | |||||
| /// | |||||
| /// The returned graph will be the innermost graph on which a | |||||
| /// `Graph.as_default()` context has been entered, or a global default | |||||
| /// graph if none has been explicitly created. | |||||
| /// | |||||
| /// NOTE: The default graph is a property of the current thread.If you | |||||
| /// create a new thread, and wish to use the default graph in that | |||||
| /// thread, you must explicitly add a `with g.as_default():` in that | |||||
| /// thread's function. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public static Graph get_default_graph() | |||||
| { | |||||
| //TODO: original source indicates there should be a _default_graph_stack! | |||||
| //return _default_graph_stack.get_default() | |||||
| return default_graph_stack.get_controller(); | |||||
| } | |||||
| public static Graph set_default_graph(Graph graph) | |||||
| { | |||||
| //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! | |||||
| default_graph_stack.set_controller(graph); | |||||
| return default_graph_stack.get_controller(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Clears the default graph stack and resets the global default graph. | |||||
| /// | |||||
| /// NOTE: The default graph is a property of the current thread.This | |||||
| /// function applies only to the current thread.Calling this function while | |||||
| /// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined | |||||
| /// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects | |||||
| /// after calling this function will result in undefined behavior. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public static void reset_default_graph() | |||||
| { | |||||
| //TODO: original source indicates there should be a _default_graph_stack! | |||||
| //if (!_default_graph_stack.is_cleared()) | |||||
| // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | |||||
| // "nested graphs. If you need a cleared graph, " + | |||||
| // "exit the nesting and create a new graph."); | |||||
| default_graph_stack.reset(); | |||||
| } | |||||
| public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) | public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) | ||||
| => _get_graph_from_inputs(op_input_list: op_input_list, graph: null); | => _get_graph_from_inputs(op_input_list: op_input_list, graph: null); | ||||
| @@ -399,15 +348,6 @@ namespace Tensorflow | |||||
| return session.run(tensor, feed_dict); | return session.run(tensor, feed_dict); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Returns the default session for the current thread. | |||||
| /// </summary> | |||||
| /// <returns>The default `Session` being used in the current thread.</returns> | |||||
| public static Session get_default_session() | |||||
| { | |||||
| return tf.defaultSession; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Prepends name scope to a name. | /// Prepends name scope to a name. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -0,0 +1,152 @@ | |||||
| using System.Threading; | |||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class ops | |||||
| { | |||||
| private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack()); | |||||
| private static volatile Session _singleSesson; | |||||
| private static volatile DefaultGraphStack _singleGraphStack; | |||||
| private static readonly object _threadingLock = new object(); | |||||
| public static DefaultGraphStack default_graph_stack | |||||
| { | |||||
| get | |||||
| { | |||||
| if (!isSingleThreaded) | |||||
| return _defaultGraphFactory.Value; | |||||
| if (_singleGraphStack == null) | |||||
| { | |||||
| lock (_threadingLock) | |||||
| { | |||||
| if (_singleGraphStack == null) | |||||
| _singleGraphStack = new DefaultGraphStack(); | |||||
| } | |||||
| } | |||||
| return _singleGraphStack; | |||||
| } | |||||
| } | |||||
| private static bool isSingleThreaded = false; | |||||
| /// <summary> | |||||
| /// Does this library ignore different thread accessing. | |||||
| /// </summary> | |||||
| /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading </remarks> | |||||
| public static bool IsSingleThreaded | |||||
| { | |||||
| get => isSingleThreaded; | |||||
| set | |||||
| { | |||||
| if (value) | |||||
| enforce_singlethreading(); | |||||
| else | |||||
| enforce_multithreading(); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Forces the library to ignore different thread accessing. | |||||
| /// </summary> | |||||
| /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a multithreaded manner</remarks> | |||||
| public static void enforce_singlethreading() | |||||
| { | |||||
| isSingleThreaded = true; | |||||
| } | |||||
| /// <summary> | |||||
| /// Forces the library to provide a separate <see cref="Session"/> and <see cref="Graph"/> to every different thread accessing. | |||||
| /// </summary> | |||||
| /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a singlethreaded manner</remarks> | |||||
| public static void enforce_multithreading() | |||||
| { | |||||
| isSingleThreaded = false; | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns the default session for the current thread. | |||||
| /// </summary> | |||||
| /// <returns>The default `Session` being used in the current thread.</returns> | |||||
| public static Session get_default_session() | |||||
| { | |||||
| if (!isSingleThreaded) | |||||
| return tf.defaultSession; | |||||
| if (_singleSesson == null) | |||||
| { | |||||
| lock (_threadingLock) | |||||
| { | |||||
| if (_singleSesson == null) | |||||
| _singleSesson = new Session(); | |||||
| } | |||||
| } | |||||
| return _singleSesson; | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns the default session for the current thread. | |||||
| /// </summary> | |||||
| /// <returns>The default `Session` being used in the current thread.</returns> | |||||
| public static Session set_default_session(Session sess) | |||||
| { | |||||
| if (!isSingleThreaded) | |||||
| return tf.defaultSession = sess; | |||||
| lock (_threadingLock) | |||||
| { | |||||
| _singleSesson = sess; | |||||
| } | |||||
| return _singleSesson; | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns the default graph for the current thread. | |||||
| /// | |||||
| /// The returned graph will be the innermost graph on which a | |||||
| /// `Graph.as_default()` context has been entered, or a global default | |||||
| /// graph if none has been explicitly created. | |||||
| /// | |||||
| /// NOTE: The default graph is a property of the current thread.If you | |||||
| /// create a new thread, and wish to use the default graph in that | |||||
| /// thread, you must explicitly add a `with g.as_default():` in that | |||||
| /// thread's function. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public static Graph get_default_graph() | |||||
| { | |||||
| //return _default_graph_stack.get_default() | |||||
| return default_graph_stack.get_controller(); | |||||
| } | |||||
| public static Graph set_default_graph(Graph graph) | |||||
| { | |||||
| default_graph_stack.set_controller(graph); | |||||
| return default_graph_stack.get_controller(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Clears the default graph stack and resets the global default graph. | |||||
| /// | |||||
| /// NOTE: The default graph is a property of the current thread.This | |||||
| /// function applies only to the current thread.Calling this function while | |||||
| /// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined | |||||
| /// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects | |||||
| /// after calling this function will result in undefined behavior. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public static void reset_default_graph() | |||||
| { | |||||
| //if (!_default_graph_stack.is_cleared()) | |||||
| // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | |||||
| // "nested graphs. If you need a cleared graph, " + | |||||
| // "exit the nesting and create a new graph."); | |||||
| default_graph_stack.reset(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -21,8 +21,6 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class tensorflow : IObjectLife | public partial class tensorflow : IObjectLife | ||||
| { | { | ||||
| protected internal readonly ThreadLocal<Session> _defaultSessionFactory; | |||||
| public TF_DataType @byte = TF_DataType.TF_UINT8; | public TF_DataType @byte = TF_DataType.TF_UINT8; | ||||
| public TF_DataType @sbyte = TF_DataType.TF_INT8; | public TF_DataType @sbyte = TF_DataType.TF_INT8; | ||||
| public TF_DataType int16 = TF_DataType.TF_INT16; | public TF_DataType int16 = TF_DataType.TF_INT16; | ||||
| @@ -40,10 +38,10 @@ namespace Tensorflow | |||||
| public tensorflow() | public tensorflow() | ||||
| { | { | ||||
| _defaultSessionFactory = new ThreadLocal<Session>(() => new Session()); | |||||
| _constructThreadingObjects(); | |||||
| } | } | ||||
| public Session defaultSession => _defaultSessionFactory.Value; | |||||
| public RefVariable Variable<T>(T data, | public RefVariable Variable<T>(T data, | ||||
| bool trainable = true, | bool trainable = true, | ||||
| @@ -0,0 +1,53 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Threading; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class tensorflow : IObjectLife | |||||
| { | |||||
| protected ThreadLocal<Session> _defaultSessionFactory; | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public void _constructThreadingObjects() | |||||
| { | |||||
| _defaultSessionFactory = new ThreadLocal<Session>(() => new Session()); | |||||
| } | |||||
| public Session defaultSession | |||||
| { | |||||
| get | |||||
| { | |||||
| if (!ops.IsSingleThreaded) | |||||
| return _defaultSessionFactory.Value; | |||||
| return ops.get_default_session(); | |||||
| } | |||||
| internal set | |||||
| { | |||||
| if (!ops.IsSingleThreaded) | |||||
| { | |||||
| _defaultSessionFactory.Value = value; | |||||
| return; | |||||
| } | |||||
| ops.set_default_session(value); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,107 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.IO; | |||||
| using System.Linq; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Threading; | |||||
| using FluentAssertions; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using NumSharp; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | |||||
| [TestClass] | |||||
| public class EnforcedSinglethreadingTests : CApiTest | |||||
| { | |||||
| private static readonly object _singlethreadLocker = new object(); | |||||
| /// <summary>Initializes a new instance of the <see cref="T:System.Object" /> class.</summary> | |||||
| public EnforcedSinglethreadingTests() | |||||
| { | |||||
| ops.IsSingleThreaded = true; | |||||
| } | |||||
| [TestMethod, Ignore("Has to be tested manually.")] | |||||
| public void SessionCreation() | |||||
| { | |||||
| lock (_singlethreadLocker) | |||||
| { | |||||
| ops.IsSingleThreaded.Should().BeTrue(); | |||||
| ops.uid(); //increment id by one | |||||
| //the core method | |||||
| tf.peak_default_graph().Should().BeNull(); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var default_graph = tf.peak_default_graph(); | |||||
| var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||||
| sess_graph.Should().NotBeNull(); | |||||
| default_graph.Should().NotBeNull() | |||||
| .And.BeEquivalentTo(sess_graph); | |||||
| var (graph, session) = Parallely(() => (tf.get_default_graph(), tf.get_default_session())); | |||||
| graph.Should().BeEquivalentTo(default_graph); | |||||
| session.Should().BeEquivalentTo(sess); | |||||
| } | |||||
| } | |||||
| } | |||||
| T Parallely<T>(Func<T> fnc) | |||||
| { | |||||
| var mrh = new ManualResetEventSlim(); | |||||
| T ret = default; | |||||
| Exception e = default; | |||||
| new Thread(() => | |||||
| { | |||||
| try | |||||
| { | |||||
| ret = fnc(); | |||||
| } catch (Exception ee) | |||||
| { | |||||
| e = ee; | |||||
| throw; | |||||
| } finally | |||||
| { | |||||
| mrh.Set(); | |||||
| } | |||||
| }).Start(); | |||||
| if (!Debugger.IsAttached) | |||||
| mrh.Wait(10000).Should().BeTrue(); | |||||
| else | |||||
| mrh.Wait(-1); | |||||
| e.Should().BeNull(e?.ToString()); | |||||
| return ret; | |||||
| } | |||||
| void Parallely(Action fnc) | |||||
| { | |||||
| var mrh = new ManualResetEventSlim(); | |||||
| Exception e = default; | |||||
| new Thread(() => | |||||
| { | |||||
| try | |||||
| { | |||||
| fnc(); | |||||
| } catch (Exception ee) | |||||
| { | |||||
| e = ee; | |||||
| throw; | |||||
| } finally | |||||
| { | |||||
| mrh.Set(); | |||||
| } | |||||
| }).Start(); | |||||
| mrh.Wait(10000).Should().BeTrue(); | |||||
| e.Should().BeNull(e.ToString()); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -283,14 +283,11 @@ namespace TensorFlowNET.UnitTest | |||||
| } | } | ||||
| } | } | ||||
| private static string modelPath = "./model/"; | |||||
| private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/"); | |||||
| [TestMethod] | [TestMethod] | ||||
| public void TF_GraphOperationByName_FromModel() | public void TF_GraphOperationByName_FromModel() | ||||
| { | { | ||||
| if (!Directory.Exists(modelPath)) | |||||
| return; | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | MultiThreadedUnitTestExecuter.Run(8, Core); | ||||
| //the core method | //the core method | ||||
| @@ -43,6 +43,9 @@ | |||||
| <None Update="model\saved_model.pb"> | <None Update="model\saved_model.pb"> | ||||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | ||||
| </None> | </None> | ||||
| <None Update="Utilities\models\example1\saved_model.pb"> | |||||
| <CopyToOutputDirectory>Always</CopyToOutputDirectory> | |||||
| </None> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| </Project> | </Project> | ||||