| @@ -20,8 +20,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class tensorflow | public partial class tensorflow | ||||
| { | { | ||||
| public graph_util_impl graph_util => new graph_util_impl(); | |||||
| public GraphTransformer graph_transforms => new GraphTransformer(); | |||||
| public graph_util_impl graph_util { get; } = new graph_util_impl(); | |||||
| public GraphTransformer graph_transforms { get; } = new GraphTransformer(); | |||||
| public GraphKeys GraphKeys { get; } = new GraphKeys(); | public GraphKeys GraphKeys { get; } = new GraphKeys(); | ||||
| public void reset_default_graph() | public void reset_default_graph() | ||||
| @@ -171,7 +171,7 @@ namespace Tensorflow.Contexts | |||||
| public void reset_context() | public void reset_context() | ||||
| { | { | ||||
| ops.reset_uid(); | |||||
| // ops.reset_uid(); | |||||
| // tf.defaultSession = null; | // tf.defaultSession = null; | ||||
| ops.reset_default_graph(); | ops.reset_default_graph(); | ||||
| context_switches.Clear(); | context_switches.Clear(); | ||||
| @@ -14,9 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -25,19 +24,14 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class DefaultGraphStack | public class DefaultGraphStack | ||||
| { | { | ||||
| private readonly Stack<Graph> _stack = new Stack<Graph>(); | |||||
| Graph _global_default_graph; | |||||
| Stack<Graph> _stack = new Stack<Graph>(); | |||||
| public Graph get_default() | public Graph get_default() | ||||
| { | { | ||||
| if (_stack.Count > 0) | |||||
| return _stack.Peek(); | |||||
| else if (_global_default_graph != null) | |||||
| return _global_default_graph; | |||||
| else | |||||
| _global_default_graph = new Graph(); | |||||
| if (_stack.Count == 0) | |||||
| _stack.Push(new Graph()); | |||||
| return _global_default_graph; | |||||
| return _stack.Peek(); | |||||
| } | } | ||||
| public Graph get_controller(Graph g) | public Graph get_controller(Graph g) | ||||
| @@ -61,7 +55,6 @@ namespace Tensorflow | |||||
| public void reset() | public void reset() | ||||
| { | { | ||||
| _stack.Clear(); | _stack.Clear(); | ||||
| _global_default_graph = null; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -107,7 +107,7 @@ namespace Tensorflow.NumPy | |||||
| if (tensor.Handle == null) | if (tensor.Handle == null) | ||||
| { | { | ||||
| if (tf.executing_eagerly()) | if (tf.executing_eagerly()) | ||||
| tensor = tf.defaultSession.eval(tensor); | |||||
| tensor = tf.get_default_session().eval(tensor); | |||||
| } | } | ||||
| return new NDArray(tensor, tf.executing_eagerly()); | return new NDArray(tensor, tf.executing_eagerly()); | ||||
| @@ -38,7 +38,7 @@ namespace Tensorflow.NumPy | |||||
| { | { | ||||
| if (_handle is null) | if (_handle is null) | ||||
| { | { | ||||
| tensor = tf.defaultSession.eval(tensor); | |||||
| tensor = tf.get_default_session().eval(tensor); | |||||
| _handle = tensor.Handle; | _handle = tensor.Handle; | ||||
| } | } | ||||
| @@ -23,7 +23,7 @@ namespace Tensorflow.Variables | |||||
| { | { | ||||
| // gen_resource_variable_ops.destroy_resource_op(_tensor, ignore_lookup_error: true); | // gen_resource_variable_ops.destroy_resource_op(_tensor, ignore_lookup_error: true); | ||||
| tf.device(_handle_device); | |||||
| // tf.device(_handle_device); | |||||
| tf.Runner.TFE_Execute(tf.Context, _handle_device, "DestroyResourceOp", | tf.Runner.TFE_Execute(tf.Context, _handle_device, "DestroyResourceOp", | ||||
| new[] { _tensor }, | new[] { _tensor }, | ||||
| new object[] { "ignore_lookup_error", true }, 0); | new object[] { "ignore_lookup_error", true }, 0); | ||||
| @@ -1,70 +1,15 @@ | |||||
| using System.Threading; | |||||
| using System; | |||||
| using System.Threading; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class ops | public partial class ops | ||||
| { | { | ||||
| private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack()); | |||||
| private static volatile Session _singleSesson; | |||||
| private static volatile DefaultGraphStack _singleGraphStack; | |||||
| private static readonly object _threadingLock = new object(); | |||||
| public static DefaultGraphStack default_graph_stack | |||||
| { | |||||
| get | |||||
| { | |||||
| if (!isSingleThreaded) | |||||
| return _defaultGraphFactory.Value; | |||||
| if (_singleGraphStack == null) | |||||
| { | |||||
| lock (_threadingLock) | |||||
| { | |||||
| if (_singleGraphStack == null) | |||||
| _singleGraphStack = new DefaultGraphStack(); | |||||
| } | |||||
| } | |||||
| return _singleGraphStack; | |||||
| } | |||||
| } | |||||
| private static bool isSingleThreaded = false; | |||||
| /// <summary> | |||||
| /// Does this library ignore different thread accessing. | |||||
| /// </summary> | |||||
| /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading </remarks> | |||||
| public static bool IsSingleThreaded | |||||
| { | |||||
| get => isSingleThreaded; | |||||
| set | |||||
| { | |||||
| if (value) | |||||
| enforce_singlethreading(); | |||||
| else | |||||
| enforce_multithreading(); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Forces the library to ignore different thread accessing. | |||||
| /// </summary> | |||||
| /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a multithreaded manner</remarks> | |||||
| public static void enforce_singlethreading() | |||||
| { | |||||
| isSingleThreaded = true; | |||||
| } | |||||
| /// <summary> | |||||
| /// Forces the library to provide a separate <see cref="Session"/> and <see cref="Graph"/> to every different thread accessing. | |||||
| /// </summary> | |||||
| /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a singlethreaded manner</remarks> | |||||
| public static void enforce_multithreading() | |||||
| { | |||||
| isSingleThreaded = false; | |||||
| } | |||||
| [ThreadStatic] | |||||
| static DefaultGraphStack default_graph_stack = new DefaultGraphStack(); | |||||
| [ThreadStatic] | |||||
| static Session defaultSession; | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the default session for the current thread. | /// Returns the default session for the current thread. | ||||
| @@ -72,19 +17,10 @@ namespace Tensorflow | |||||
| /// <returns>The default `Session` being used in the current thread.</returns> | /// <returns>The default `Session` being used in the current thread.</returns> | ||||
| public static Session get_default_session() | public static Session get_default_session() | ||||
| { | { | ||||
| if (!isSingleThreaded) | |||||
| return tf.defaultSession; | |||||
| if (defaultSession == null) | |||||
| defaultSession = new Session(tf.get_default_graph()); | |||||
| if (_singleSesson == null) | |||||
| { | |||||
| lock (_threadingLock) | |||||
| { | |||||
| if (_singleSesson == null) | |||||
| _singleSesson = new Session(); | |||||
| } | |||||
| } | |||||
| return _singleSesson; | |||||
| return defaultSession; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -93,15 +29,8 @@ namespace Tensorflow | |||||
| /// <returns>The default `Session` being used in the current thread.</returns> | /// <returns>The default `Session` being used in the current thread.</returns> | ||||
| public static Session set_default_session(Session sess) | public static Session set_default_session(Session sess) | ||||
| { | { | ||||
| if (!isSingleThreaded) | |||||
| return tf.defaultSession = sess; | |||||
| lock (_threadingLock) | |||||
| { | |||||
| _singleSesson = sess; | |||||
| } | |||||
| return _singleSesson; | |||||
| defaultSession = sess; | |||||
| return sess; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -118,10 +47,18 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Graph get_default_graph() | public static Graph get_default_graph() | ||||
| => default_graph_stack.get_default(); | |||||
| { | |||||
| if (default_graph_stack == null) | |||||
| default_graph_stack = new DefaultGraphStack(); | |||||
| return default_graph_stack.get_default(); | |||||
| } | |||||
| public static Graph set_default_graph(Graph g) | public static Graph set_default_graph(Graph g) | ||||
| => default_graph_stack.get_controller(g); | |||||
| { | |||||
| if (default_graph_stack == null) | |||||
| default_graph_stack = new DefaultGraphStack(); | |||||
| return default_graph_stack.get_controller(g); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Clears the default graph stack and resets the global default graph. | /// Clears the default graph stack and resets the global default graph. | ||||
| @@ -135,6 +72,8 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static void reset_default_graph() | public static void reset_default_graph() | ||||
| { | { | ||||
| if (default_graph_stack == null) | |||||
| return; | |||||
| //if (!_default_graph_stack.is_cleared()) | //if (!_default_graph_stack.is_cleared()) | ||||
| // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | ||||
| // "nested graphs. If you need a cleared graph, " + | // "nested graphs. If you need a cleared graph, " + | ||||
| @@ -143,7 +82,11 @@ namespace Tensorflow | |||||
| } | } | ||||
| public static Graph peak_default_graph() | public static Graph peak_default_graph() | ||||
| => default_graph_stack.peak_controller(); | |||||
| { | |||||
| if (default_graph_stack == null) | |||||
| default_graph_stack = new DefaultGraphStack(); | |||||
| return default_graph_stack.peak_controller(); | |||||
| } | |||||
| public static void pop_graph() | public static void pop_graph() | ||||
| => default_graph_stack.pop(); | => default_graph_stack.pop(); | ||||
| @@ -16,6 +16,7 @@ | |||||
| using Serilog; | using Serilog; | ||||
| using Serilog.Core; | using Serilog.Core; | ||||
| using System.Threading; | |||||
| using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Gradients; | using Tensorflow.Gradients; | ||||
| @@ -38,12 +39,18 @@ namespace Tensorflow | |||||
| public TF_DataType chars = TF_DataType.TF_STRING; | public TF_DataType chars = TF_DataType.TF_STRING; | ||||
| public TF_DataType @string = TF_DataType.TF_STRING; | public TF_DataType @string = TF_DataType.TF_STRING; | ||||
| public Status Status; | |||||
| public OpDefLibrary OpDefLib; | public OpDefLibrary OpDefLib; | ||||
| public Context Context; | |||||
| public IEagerRunner Runner; | |||||
| public Logger Logger; | public Logger Logger; | ||||
| ThreadLocal<Status> _status = new ThreadLocal<Status>(() => new Status()); | |||||
| public Status Status => _status.Value; | |||||
| ThreadLocal<Context> _context = new ThreadLocal<Context>(() => new Context()); | |||||
| public Context Context => _context.Value; | |||||
| ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner()); | |||||
| public IEagerRunner Runner => _runner.Value; | |||||
| public tensorflow() | public tensorflow() | ||||
| { | { | ||||
| Logger = new LoggerConfiguration() | Logger = new LoggerConfiguration() | ||||
| @@ -51,12 +58,8 @@ namespace Tensorflow | |||||
| .WriteTo.Console() | .WriteTo.Console() | ||||
| .CreateLogger(); | .CreateLogger(); | ||||
| Status = new Status(); | |||||
| Context = new Context(); | |||||
| OpDefLib = new OpDefLibrary(); | OpDefLib = new OpDefLibrary(); | ||||
| ConstructThreadingObjects(); | |||||
| InitGradientEnvironment(); | InitGradientEnvironment(); | ||||
| Runner = new EagerRunner(); | |||||
| } | } | ||||
| public string VERSION => c_api.StringPiece(c_api.TF_Version()); | public string VERSION => c_api.StringPiece(c_api.TF_Version()); | ||||
| @@ -1,53 +0,0 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Threading; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class tensorflow | |||||
| { | |||||
| protected ThreadLocal<Session> defaultSessionFactory; | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public void ConstructThreadingObjects() | |||||
| { | |||||
| defaultSessionFactory = new ThreadLocal<Session>(() => new Session()); | |||||
| } | |||||
| public Session defaultSession | |||||
| { | |||||
| get | |||||
| { | |||||
| if (!ops.IsSingleThreaded) | |||||
| return defaultSessionFactory.Value; | |||||
| return ops.get_default_session(); | |||||
| } | |||||
| internal set | |||||
| { | |||||
| if (!ops.IsSingleThreaded) | |||||
| { | |||||
| defaultSessionFactory.Value = value; | |||||
| return; | |||||
| } | |||||
| ops.set_default_session(value); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -12,6 +12,7 @@ using Tensorflow.Keras.Models; | |||||
| using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using System.Threading; | |||||
| namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
| { | { | ||||
| @@ -9,6 +9,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
| using System.Threading; | |||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| { | { | ||||
| @@ -40,24 +41,24 @@ namespace Tensorflow.Keras.Layers | |||||
| return MakOp(inputs); | return MakOp(inputs); | ||||
| } | } | ||||
| ConcreteFunction function; | |||||
| ThreadLocal<ConcreteFunction> function = new ThreadLocal<ConcreteFunction>(); | |||||
| Tensors DeFunCall(Tensors inputs) | Tensors DeFunCall(Tensors inputs) | ||||
| { | { | ||||
| if(function == null) | |||||
| if (function.Value == null) | |||||
| { | { | ||||
| function = new ConcreteFunction(name); | |||||
| function.Enter(); | |||||
| function.Value = new ConcreteFunction(name); | |||||
| function.Value.Enter(); | |||||
| int i = 0; | int i = 0; | ||||
| var graph_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: $"defun_inputs_{i++}")).ToArray(); | var graph_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: $"defun_inputs_{i++}")).ToArray(); | ||||
| var graph_outputs = MakOp(graph_inputs); | var graph_outputs = MakOp(graph_inputs); | ||||
| graph_outputs = mark_as_return(graph_outputs); | graph_outputs = mark_as_return(graph_outputs); | ||||
| function.ToGraph(graph_inputs, graph_outputs); | |||||
| function.Exit(); | |||||
| function.Value.ToGraph(graph_inputs, graph_outputs); | |||||
| function.Value.Exit(); | |||||
| } | } | ||||
| var outputs = function.FilteredCall(inputs); | |||||
| var outputs = function.Value.FilteredCall(inputs); | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| @@ -24,14 +24,12 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| Assert.IsNull(tf.peak_default_graph()); | Assert.IsNull(tf.peak_default_graph()); | ||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var default_graph = tf.get_default_graph(); | |||||
| var sess_graph = sess.graph; | |||||
| Assert.IsNotNull(default_graph); | |||||
| Assert.IsNotNull(sess_graph); | |||||
| Assert.AreEqual(default_graph, sess_graph); | |||||
| } | |||||
| using var sess = tf.Session(); | |||||
| var default_graph = tf.get_default_graph(); | |||||
| var sess_graph = sess.graph; | |||||
| Assert.IsNotNull(default_graph); | |||||
| Assert.IsNotNull(sess_graph); | |||||
| Assert.AreEqual(default_graph, sess_graph); | |||||
| } | } | ||||
| } | } | ||||
| @@ -47,14 +45,12 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| Assert.IsNull(tf.peak_default_graph()); | Assert.IsNull(tf.peak_default_graph()); | ||||
| //tf.Session created an other graph | //tf.Session created an other graph | ||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var default_graph = tf.get_default_graph(); | |||||
| var sess_graph = sess.graph; | |||||
| Assert.IsNotNull(default_graph); | |||||
| Assert.IsNotNull(sess_graph); | |||||
| Assert.AreEqual(default_graph, sess_graph); | |||||
| } | |||||
| using var sess = tf.Session(); | |||||
| var default_graph = tf.get_default_graph(); | |||||
| var sess_graph = sess.graph; | |||||
| Assert.IsNotNull(default_graph); | |||||
| Assert.IsNotNull(sess_graph); | |||||
| Assert.AreEqual(default_graph, sess_graph); | |||||
| } | } | ||||
| } | } | ||||
| @@ -73,20 +69,12 @@ namespace TensorFlowNET.UnitTest | |||||
| beforehand.as_default(); | beforehand.as_default(); | ||||
| Assert.IsNotNull(tf.peak_default_graph()); | Assert.IsNotNull(tf.peak_default_graph()); | ||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var default_graph = tf.peak_default_graph(); | |||||
| var sess_graph = sess.graph; | |||||
| Assert.IsNotNull(default_graph); | |||||
| Assert.IsNotNull(sess_graph); | |||||
| Assert.AreEqual(default_graph, sess_graph); | |||||
| Console.WriteLine($"{tid}-{default_graph.graph_key}"); | |||||
| //var result = sess.run(new object[] {g, a}); | |||||
| //var actualDeriv = result[0].GetData<float>()[0]; | |||||
| //var actual = result[1].GetData<float>()[0]; | |||||
| } | |||||
| using var sess = tf.Session(); | |||||
| var default_graph = tf.peak_default_graph(); | |||||
| var sess_graph = sess.graph; | |||||
| Assert.IsNotNull(default_graph); | |||||
| Assert.IsNotNull(sess_graph); | |||||
| Assert.AreEqual(default_graph, sess_graph); | |||||
| } | } | ||||
| } | } | ||||
| @@ -114,13 +102,10 @@ namespace TensorFlowNET.UnitTest | |||||
| //the core method | //the core method | ||||
| void Core(int tid) | void Core(int tid) | ||||
| { | { | ||||
| using (var sess = tf.Session()) | |||||
| using var sess = tf.Session(); | |||||
| for (int i = 0; i < 100; i++) | |||||
| { | { | ||||
| Tensor t = null; | |||||
| for (int i = 0; i < 100; i++) | |||||
| { | |||||
| t = new Tensor(1); | |||||
| } | |||||
| var t = new Tensor(1); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -134,12 +119,10 @@ namespace TensorFlowNET.UnitTest | |||||
| void Core(int tid) | void Core(int tid) | ||||
| { | { | ||||
| //tf.Session created an other graph | //tf.Session created an other graph | ||||
| using (var sess = tf.Session()) | |||||
| using var sess = tf.Session(); | |||||
| for (int i = 0; i < 100; i++) | |||||
| { | { | ||||
| for (int i = 0; i < 100; i++) | |||||
| { | |||||
| var t = new Tensor(new int[] { 1, 2, 3 }); | |||||
| } | |||||
| var t = new Tensor(new int[] { 1, 2, 3 }); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -147,23 +130,23 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestMethod] | [TestMethod] | ||||
| public void SessionRun() | public void SessionRun() | ||||
| { | { | ||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
| MultiThreadedUnitTestExecuter.Run(2, Core); | |||||
| //the core method | //the core method | ||||
| void Core(int tid) | void Core(int tid) | ||||
| { | { | ||||
| tf.compat.v1.disable_eager_execution(); | |||||
| var graph = tf.Graph().as_default(); | |||||
| //graph is created automatically to perform create these operations | //graph is created automatically to perform create these operations | ||||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | ||||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | ||||
| var math = a1 + a2; | var math = a1 + a2; | ||||
| using var sess = tf.Session(graph); | |||||
| for (int i = 0; i < 100; i++) | for (int i = 0; i < 100; i++) | ||||
| { | { | ||||
| var graph = tf.get_default_graph(); | |||||
| using (var sess = tf.Session(graph)) | |||||
| { | |||||
| var result = sess.run(math); | |||||
| Assert.AreEqual(result[0], 5f); | |||||
| } | |||||
| var result = sess.run(math); | |||||
| Assert.AreEqual(result[0], 5f); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -176,17 +159,18 @@ namespace TensorFlowNET.UnitTest | |||||
| //the core method | //the core method | ||||
| void Core(int tid) | void Core(int tid) | ||||
| { | { | ||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| Assert.IsNotNull(tf.get_default_graph()); | |||||
| //graph is created automatically to perform create these operations | |||||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||||
| var math = a1 + a2; | |||||
| tf.compat.v1.disable_eager_execution(); | |||||
| var graph = tf.Graph().as_default(); | |||||
| var result = sess.run(math); | |||||
| Assert.AreEqual(result[0], 5f); | |||||
| } | |||||
| using var sess = tf.Session(graph); | |||||
| Assert.IsNotNull(tf.get_default_graph()); | |||||
| //graph is created automatically to perform create these operations | |||||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||||
| var math = a1 + a2; | |||||
| var result = sess.run(math); | |||||
| Assert.AreEqual(result[0], 5f); | |||||
| } | } | ||||
| } | } | ||||
| @@ -198,14 +182,12 @@ namespace TensorFlowNET.UnitTest | |||||
| //the core method | //the core method | ||||
| void Core(int tid) | void Core(int tid) | ||||
| { | { | ||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| Assert.IsNotNull(tf.get_default_graph()); | |||||
| //graph is created automatically to perform create these operations | |||||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||||
| var math = a1 + a2; | |||||
| } | |||||
| using var sess = tf.Session(); | |||||
| Assert.IsNotNull(tf.get_default_graph()); | |||||
| //graph is created automatically to perform create these operations | |||||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||||
| var math = a1 + a2; | |||||
| } | } | ||||
| } | } | ||||
| @@ -234,6 +216,10 @@ namespace TensorFlowNET.UnitTest | |||||
| void Core(int tid) | void Core(int tid) | ||||
| { | { | ||||
| Assert.IsNull(tf.peak_default_graph()); | Assert.IsNull(tf.peak_default_graph()); | ||||
| tf.compat.v1.disable_eager_execution(); | |||||
| var graph = tf.Graph().as_default(); | |||||
| //graph is created automatically to perform create these operations | //graph is created automatically to perform create these operations | ||||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | ||||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }, name: "ConstantK"); | var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }, name: "ConstantK"); | ||||
| @@ -248,7 +234,6 @@ namespace TensorFlowNET.UnitTest | |||||
| private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/"); | private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/"); | ||||
| [Ignore] | [Ignore] | ||||
| [TestMethod] | |||||
| public void TF_GraphOperationByName_FromModel() | public void TF_GraphOperationByName_FromModel() | ||||
| { | { | ||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | MultiThreadedUnitTestExecuter.Run(8, Core); | ||||
| @@ -0,0 +1,95 @@ | |||||
| using System; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.KerasApi; | |||||
| using System.Threading.Tasks; | |||||
| using Tensorflow.NumPy; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| namespace TensorFlowNET.Keras.UnitTest | |||||
| { | |||||
| [TestClass] | |||||
| public class MultiThreads | |||||
| { | |||||
| [TestMethod] | |||||
| public void Test1() | |||||
| { | |||||
| //Arrange | |||||
| string savefile = "mymodel.h5"; | |||||
| var model1 = BuildModel(); | |||||
| model1.save_weights(savefile); | |||||
| var model2 = BuildModel(); | |||||
| //act | |||||
| model1.load_weights(savefile); | |||||
| model2.load_weights(savefile); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Test2() | |||||
| { | |||||
| //Arrange | |||||
| string savefile = "mymodel2.h5"; | |||||
| var model1 = BuildModel(); | |||||
| model1.save_weights(savefile); | |||||
| model1 = BuildModel(); //recreate model | |||||
| //act | |||||
| model1.load_weights(savefile); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Test3Multithreading() | |||||
| { | |||||
| //Arrange | |||||
| string savefile = "mymodel3.h5"; | |||||
| var model = BuildModel(); | |||||
| model.save_weights(savefile); | |||||
| //Sanity check without multithreading | |||||
| for (int i = 0; i < 2; i++) | |||||
| { | |||||
| Functional clone = BuildModel(); | |||||
| clone.load_weights(savefile); | |||||
| //Predict something | |||||
| clone.predict(np.array(new float[,] { { 0, 0 } })); | |||||
| } //works | |||||
| //act | |||||
| ParallelOptions parallelOptions = new ParallelOptions(); | |||||
| parallelOptions.MaxDegreeOfParallelism = 1; | |||||
| var input = np.array(new float[,] { { 0, 0 } }); | |||||
| Parallel.For(0, 1, parallelOptions, i => | |||||
| { | |||||
| var clone = BuildModel(); | |||||
| clone.load_weights(savefile); | |||||
| //Predict something | |||||
| clone.predict(input); | |||||
| }); | |||||
| } | |||||
| Functional BuildModel() | |||||
| { | |||||
| tf.Context.reset_context(); | |||||
| var inputs = keras.Input(shape: 2); | |||||
| // 1st dense layer | |||||
| var DenseLayer = keras.layers.Dense(1, activation: keras.activations.Sigmoid); | |||||
| var outputs = DenseLayer.Apply(inputs); | |||||
| // build keras model | |||||
| Functional model = keras.Model(inputs, outputs, name: Guid.NewGuid().ToString()); | |||||
| // show model summary | |||||
| model.summary(); | |||||
| // compile keras model into tensorflow's static graph | |||||
| model.compile(loss: keras.losses.MeanSquaredError(name: Guid.NewGuid().ToString()), | |||||
| optimizer: keras.optimizers.Adam(name: Guid.NewGuid().ToString()), | |||||
| metrics: new[] { "accuracy" }); | |||||
| return model; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -16,7 +16,6 @@ namespace TensorFlowNET.UnitTest | |||||
| /// <summary>Initializes a new instance of the <see cref="T:System.Object" /> class.</summary> | /// <summary>Initializes a new instance of the <see cref="T:System.Object" /> class.</summary> | ||||
| public EnforcedSinglethreadingTests() | public EnforcedSinglethreadingTests() | ||||
| { | { | ||||
| ops.IsSingleThreaded = true; | |||||
| } | } | ||||
| [TestMethod, Ignore("Has to be tested manually.")] | [TestMethod, Ignore("Has to be tested manually.")] | ||||
| @@ -24,8 +23,6 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| lock (_singlethreadLocker) | lock (_singlethreadLocker) | ||||
| { | { | ||||
| ops.IsSingleThreaded.Should().BeTrue(); | |||||
| ops.uid(); //increment id by one | ops.uid(); //increment id by one | ||||
| //the core method | //the core method | ||||