- Separated multithreading related methods to classname.threading.cs partial file - ops: Added enforce_singlethreading(), enforce_multithreading()tags/v0.20
| @@ -37,8 +37,7 @@ namespace Tensorflow | |||
| public Session as_default() | |||
| { | |||
| tf._defaultSessionFactory.Value = this; | |||
| return this; | |||
| return ops.set_default_session(this); | |||
| } | |||
| [MethodImpl(MethodImplOptions.NoOptimization)] | |||
| @@ -28,10 +28,6 @@ namespace Tensorflow | |||
| { | |||
| public partial class ops | |||
| { | |||
| private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack()); | |||
| public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value; | |||
| public static int tensor_id(Tensor tensor) | |||
| { | |||
| return tensor.Id; | |||
| @@ -78,53 +74,6 @@ namespace Tensorflow | |||
| return get_default_graph().get_collection_ref<T>(key); | |||
| } | |||
| /// <summary> | |||
| /// Returns the default graph for the current thread. | |||
| /// | |||
| /// The returned graph will be the innermost graph on which a | |||
| /// `Graph.as_default()` context has been entered, or a global default | |||
| /// graph if none has been explicitly created. | |||
| /// | |||
| /// NOTE: The default graph is a property of the current thread.If you | |||
| /// create a new thread, and wish to use the default graph in that | |||
| /// thread, you must explicitly add a `with g.as_default():` in that | |||
| /// thread's function. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public static Graph get_default_graph() | |||
| { | |||
| //TODO: original source indicates there should be a _default_graph_stack! | |||
| //return _default_graph_stack.get_default() | |||
| return default_graph_stack.get_controller(); | |||
| } | |||
| public static Graph set_default_graph(Graph graph) | |||
| { | |||
| //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! | |||
| default_graph_stack.set_controller(graph); | |||
| return default_graph_stack.get_controller(); | |||
| } | |||
| /// <summary> | |||
| /// Clears the default graph stack and resets the global default graph. | |||
| /// | |||
| /// NOTE: The default graph is a property of the current thread.This | |||
| /// function applies only to the current thread.Calling this function while | |||
| /// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined | |||
| /// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects | |||
| /// after calling this function will result in undefined behavior. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public static void reset_default_graph() | |||
| { | |||
| //TODO: original source indicates there should be a _default_graph_stack! | |||
| //if (!_default_graph_stack.is_cleared()) | |||
| // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | |||
| // "nested graphs. If you need a cleared graph, " + | |||
| // "exit the nesting and create a new graph."); | |||
| default_graph_stack.reset(); | |||
| } | |||
| public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) | |||
| => _get_graph_from_inputs(op_input_list: op_input_list, graph: null); | |||
| @@ -399,15 +348,6 @@ namespace Tensorflow | |||
| return session.run(tensor, feed_dict); | |||
| } | |||
| /// <summary> | |||
| /// Returns the default session for the current thread. | |||
| /// </summary> | |||
| /// <returns>The default `Session` being used in the current thread.</returns> | |||
| public static Session get_default_session() | |||
| { | |||
| return tf.defaultSession; | |||
| } | |||
| /// <summary> | |||
| /// Prepends name scope to a name. | |||
| /// </summary> | |||
| @@ -0,0 +1,152 @@ | |||
| using System.Threading; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class ops | |||
| { | |||
| private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack()); | |||
| private static volatile Session _singleSesson; | |||
| private static volatile DefaultGraphStack _singleGraphStack; | |||
| private static readonly object _threadingLock = new object(); | |||
| public static DefaultGraphStack default_graph_stack | |||
| { | |||
| get | |||
| { | |||
| if (!isSingleThreaded) | |||
| return _defaultGraphFactory.Value; | |||
| if (_singleGraphStack == null) | |||
| { | |||
| lock (_threadingLock) | |||
| { | |||
| if (_singleGraphStack == null) | |||
| _singleGraphStack = new DefaultGraphStack(); | |||
| } | |||
| } | |||
| return _singleGraphStack; | |||
| } | |||
| } | |||
| private static bool isSingleThreaded = false; | |||
| /// <summary> | |||
| /// Does this library ignore different thread accessing. | |||
| /// </summary> | |||
| /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading </remarks> | |||
| public static bool IsSingleThreaded | |||
| { | |||
| get => isSingleThreaded; | |||
| set | |||
| { | |||
| if (value) | |||
| enforce_singlethreading(); | |||
| else | |||
| enforce_multithreading(); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Forces the library to ignore different thread accessing. | |||
| /// </summary> | |||
| /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a multithreaded manner</remarks> | |||
| public static void enforce_singlethreading() | |||
| { | |||
| isSingleThreaded = true; | |||
| } | |||
| /// <summary> | |||
| /// Forces the library to provide a separate <see cref="Session"/> and <see cref="Graph"/> to every different thread accessing. | |||
| /// </summary> | |||
| /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a singlethreaded manner</remarks> | |||
| public static void enforce_multithreading() | |||
| { | |||
| isSingleThreaded = false; | |||
| } | |||
| /// <summary> | |||
| /// Returns the default session for the current thread. | |||
| /// </summary> | |||
| /// <returns>The default `Session` being used in the current thread.</returns> | |||
| public static Session get_default_session() | |||
| { | |||
| if (!isSingleThreaded) | |||
| return tf.defaultSession; | |||
| if (_singleSesson == null) | |||
| { | |||
| lock (_threadingLock) | |||
| { | |||
| if (_singleSesson == null) | |||
| _singleSesson = new Session(); | |||
| } | |||
| } | |||
| return _singleSesson; | |||
| } | |||
| /// <summary> | |||
| /// Returns the default session for the current thread. | |||
| /// </summary> | |||
| /// <returns>The default `Session` being used in the current thread.</returns> | |||
| public static Session set_default_session(Session sess) | |||
| { | |||
| if (!isSingleThreaded) | |||
| return tf.defaultSession = sess; | |||
| lock (_threadingLock) | |||
| { | |||
| _singleSesson = sess; | |||
| } | |||
| return _singleSesson; | |||
| } | |||
| /// <summary> | |||
| /// Returns the default graph for the current thread. | |||
| /// | |||
| /// The returned graph will be the innermost graph on which a | |||
| /// `Graph.as_default()` context has been entered, or a global default | |||
| /// graph if none has been explicitly created. | |||
| /// | |||
| /// NOTE: The default graph is a property of the current thread.If you | |||
| /// create a new thread, and wish to use the default graph in that | |||
| /// thread, you must explicitly add a `with g.as_default():` in that | |||
| /// thread's function. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public static Graph get_default_graph() | |||
| { | |||
| //return _default_graph_stack.get_default() | |||
| return default_graph_stack.get_controller(); | |||
| } | |||
| public static Graph set_default_graph(Graph graph) | |||
| { | |||
| default_graph_stack.set_controller(graph); | |||
| return default_graph_stack.get_controller(); | |||
| } | |||
| /// <summary> | |||
| /// Clears the default graph stack and resets the global default graph. | |||
| /// | |||
| /// NOTE: The default graph is a property of the current thread.This | |||
| /// function applies only to the current thread.Calling this function while | |||
| /// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined | |||
| /// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects | |||
| /// after calling this function will result in undefined behavior. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public static void reset_default_graph() | |||
| { | |||
| //if (!_default_graph_stack.is_cleared()) | |||
| // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | |||
| // "nested graphs. If you need a cleared graph, " + | |||
| // "exit the nesting and create a new graph."); | |||
| default_graph_stack.reset(); | |||
| } | |||
| } | |||
| } | |||
| @@ -21,8 +21,6 @@ namespace Tensorflow | |||
| { | |||
| public partial class tensorflow : IObjectLife | |||
| { | |||
| protected internal readonly ThreadLocal<Session> _defaultSessionFactory; | |||
| public TF_DataType @byte = TF_DataType.TF_UINT8; | |||
| public TF_DataType @sbyte = TF_DataType.TF_INT8; | |||
| public TF_DataType int16 = TF_DataType.TF_INT16; | |||
| @@ -40,10 +38,10 @@ namespace Tensorflow | |||
| public tensorflow() | |||
| { | |||
| _defaultSessionFactory = new ThreadLocal<Session>(() => new Session()); | |||
| _constructThreadingObjects(); | |||
| } | |||
| public Session defaultSession => _defaultSessionFactory.Value; | |||
| public RefVariable Variable<T>(T data, | |||
| bool trainable = true, | |||
| @@ -0,0 +1,53 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.Runtime.CompilerServices; | |||
| using System.Threading; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class tensorflow : IObjectLife | |||
| { | |||
| protected ThreadLocal<Session> _defaultSessionFactory; | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public void _constructThreadingObjects() | |||
| { | |||
| _defaultSessionFactory = new ThreadLocal<Session>(() => new Session()); | |||
| } | |||
| public Session defaultSession | |||
| { | |||
| get | |||
| { | |||
| if (!ops.IsSingleThreaded) | |||
| return _defaultSessionFactory.Value; | |||
| return ops.get_default_session(); | |||
| } | |||
| internal set | |||
| { | |||
| if (!ops.IsSingleThreaded) | |||
| { | |||
| _defaultSessionFactory.Value = value; | |||
| return; | |||
| } | |||
| ops.set_default_session(value); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,107 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| using System.Threading; | |||
| using FluentAssertions; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using NumSharp; | |||
| using Tensorflow; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class EnforcedSinglethreadingTests : CApiTest | |||
| { | |||
| private static readonly object _singlethreadLocker = new object(); | |||
| /// <summary>Initializes a new instance of the <see cref="T:System.Object" /> class.</summary> | |||
| public EnforcedSinglethreadingTests() | |||
| { | |||
| ops.IsSingleThreaded = true; | |||
| } | |||
| [TestMethod, Ignore("Has to be tested manually.")] | |||
| public void SessionCreation() | |||
| { | |||
| lock (_singlethreadLocker) | |||
| { | |||
| ops.IsSingleThreaded.Should().BeTrue(); | |||
| ops.uid(); //increment id by one | |||
| //the core method | |||
| tf.peak_default_graph().Should().BeNull(); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var default_graph = tf.peak_default_graph(); | |||
| var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||
| sess_graph.Should().NotBeNull(); | |||
| default_graph.Should().NotBeNull() | |||
| .And.BeEquivalentTo(sess_graph); | |||
| var (graph, session) = Parallely(() => (tf.get_default_graph(), tf.get_default_session())); | |||
| graph.Should().BeEquivalentTo(default_graph); | |||
| session.Should().BeEquivalentTo(sess); | |||
| } | |||
| } | |||
| } | |||
| T Parallely<T>(Func<T> fnc) | |||
| { | |||
| var mrh = new ManualResetEventSlim(); | |||
| T ret = default; | |||
| Exception e = default; | |||
| new Thread(() => | |||
| { | |||
| try | |||
| { | |||
| ret = fnc(); | |||
| } catch (Exception ee) | |||
| { | |||
| e = ee; | |||
| throw; | |||
| } finally | |||
| { | |||
| mrh.Set(); | |||
| } | |||
| }).Start(); | |||
| if (!Debugger.IsAttached) | |||
| mrh.Wait(10000).Should().BeTrue(); | |||
| else | |||
| mrh.Wait(-1); | |||
| e.Should().BeNull(e?.ToString()); | |||
| return ret; | |||
| } | |||
| void Parallely(Action fnc) | |||
| { | |||
| var mrh = new ManualResetEventSlim(); | |||
| Exception e = default; | |||
| new Thread(() => | |||
| { | |||
| try | |||
| { | |||
| fnc(); | |||
| } catch (Exception ee) | |||
| { | |||
| e = ee; | |||
| throw; | |||
| } finally | |||
| { | |||
| mrh.Set(); | |||
| } | |||
| }).Start(); | |||
| mrh.Wait(10000).Should().BeTrue(); | |||
| e.Should().BeNull(e.ToString()); | |||
| } | |||
| } | |||
| } | |||
| @@ -283,14 +283,11 @@ namespace TensorFlowNET.UnitTest | |||
| } | |||
| } | |||
| private static string modelPath = "./model/"; | |||
| private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/"); | |||
| [TestMethod] | |||
| public void TF_GraphOperationByName_FromModel() | |||
| { | |||
| if (!Directory.Exists(modelPath)) | |||
| return; | |||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||
| //the core method | |||
| @@ -43,6 +43,9 @@ | |||
| <None Update="model\saved_model.pb"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| <None Update="Utilities\models\example1\saved_model.pb"> | |||
| <CopyToOutputDirectory>Always</CopyToOutputDirectory> | |||
| </None> | |||
| </ItemGroup> | |||
| </Project> | |||