| @@ -20,8 +20,8 @@ namespace Tensorflow | |||
| { | |||
| public partial class tensorflow | |||
| { | |||
| public graph_util_impl graph_util => new graph_util_impl(); | |||
| public GraphTransformer graph_transforms => new GraphTransformer(); | |||
| public graph_util_impl graph_util { get; } = new graph_util_impl(); | |||
| public GraphTransformer graph_transforms { get; } = new GraphTransformer(); | |||
| public GraphKeys GraphKeys { get; } = new GraphKeys(); | |||
| public void reset_default_graph() | |||
| @@ -171,7 +171,7 @@ namespace Tensorflow.Contexts | |||
| public void reset_context() | |||
| { | |||
| ops.reset_uid(); | |||
| // ops.reset_uid(); | |||
| // tf.defaultSession = null; | |||
| ops.reset_default_graph(); | |||
| context_switches.Clear(); | |||
| @@ -14,9 +14,8 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -25,19 +24,14 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public class DefaultGraphStack | |||
| { | |||
| private readonly Stack<Graph> _stack = new Stack<Graph>(); | |||
| Graph _global_default_graph; | |||
| Stack<Graph> _stack = new Stack<Graph>(); | |||
| public Graph get_default() | |||
| { | |||
| if (_stack.Count > 0) | |||
| return _stack.Peek(); | |||
| else if (_global_default_graph != null) | |||
| return _global_default_graph; | |||
| else | |||
| _global_default_graph = new Graph(); | |||
| if (_stack.Count == 0) | |||
| _stack.Push(new Graph()); | |||
| return _global_default_graph; | |||
| return _stack.Peek(); | |||
| } | |||
| public Graph get_controller(Graph g) | |||
| @@ -61,7 +55,6 @@ namespace Tensorflow | |||
| public void reset() | |||
| { | |||
| _stack.Clear(); | |||
| _global_default_graph = null; | |||
| } | |||
| } | |||
| } | |||
| @@ -107,7 +107,7 @@ namespace Tensorflow.NumPy | |||
| if (tensor.Handle == null) | |||
| { | |||
| if (tf.executing_eagerly()) | |||
| tensor = tf.defaultSession.eval(tensor); | |||
| tensor = tf.get_default_session().eval(tensor); | |||
| } | |||
| return new NDArray(tensor, tf.executing_eagerly()); | |||
| @@ -38,7 +38,7 @@ namespace Tensorflow.NumPy | |||
| { | |||
| if (_handle is null) | |||
| { | |||
| tensor = tf.defaultSession.eval(tensor); | |||
| tensor = tf.get_default_session().eval(tensor); | |||
| _handle = tensor.Handle; | |||
| } | |||
| @@ -23,7 +23,7 @@ namespace Tensorflow.Variables | |||
| { | |||
| // gen_resource_variable_ops.destroy_resource_op(_tensor, ignore_lookup_error: true); | |||
| tf.device(_handle_device); | |||
| // tf.device(_handle_device); | |||
| tf.Runner.TFE_Execute(tf.Context, _handle_device, "DestroyResourceOp", | |||
| new[] { _tensor }, | |||
| new object[] { "ignore_lookup_error", true }, 0); | |||
| @@ -1,70 +1,15 @@ | |||
| using System.Threading; | |||
| using System; | |||
| using System.Threading; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class ops | |||
| { | |||
| private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack()); | |||
| private static volatile Session _singleSesson; | |||
| private static volatile DefaultGraphStack _singleGraphStack; | |||
| private static readonly object _threadingLock = new object(); | |||
| public static DefaultGraphStack default_graph_stack | |||
| { | |||
| get | |||
| { | |||
| if (!isSingleThreaded) | |||
| return _defaultGraphFactory.Value; | |||
| if (_singleGraphStack == null) | |||
| { | |||
| lock (_threadingLock) | |||
| { | |||
| if (_singleGraphStack == null) | |||
| _singleGraphStack = new DefaultGraphStack(); | |||
| } | |||
| } | |||
| return _singleGraphStack; | |||
| } | |||
| } | |||
| private static bool isSingleThreaded = false; | |||
| /// <summary> | |||
| /// Does this library ignore different thread accessing. | |||
| /// </summary> | |||
| /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading </remarks> | |||
| public static bool IsSingleThreaded | |||
| { | |||
| get => isSingleThreaded; | |||
| set | |||
| { | |||
| if (value) | |||
| enforce_singlethreading(); | |||
| else | |||
| enforce_multithreading(); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Forces the library to ignore different thread accessing. | |||
| /// </summary> | |||
| /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a multithreaded manner</remarks> | |||
| public static void enforce_singlethreading() | |||
| { | |||
| isSingleThreaded = true; | |||
| } | |||
| /// <summary> | |||
| /// Forces the library to provide a separate <see cref="Session"/> and <see cref="Graph"/> to every different thread accessing. | |||
| /// </summary> | |||
| /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a singlethreaded manner</remarks> | |||
| public static void enforce_multithreading() | |||
| { | |||
| isSingleThreaded = false; | |||
| } | |||
| [ThreadStatic] | |||
| static DefaultGraphStack default_graph_stack = new DefaultGraphStack(); | |||
| [ThreadStatic] | |||
| static Session defaultSession; | |||
| /// <summary> | |||
| /// Returns the default session for the current thread. | |||
| @@ -72,19 +17,10 @@ namespace Tensorflow | |||
| /// <returns>The default `Session` being used in the current thread.</returns> | |||
| public static Session get_default_session() | |||
| { | |||
| if (!isSingleThreaded) | |||
| return tf.defaultSession; | |||
| if (defaultSession == null) | |||
| defaultSession = new Session(tf.get_default_graph()); | |||
| if (_singleSesson == null) | |||
| { | |||
| lock (_threadingLock) | |||
| { | |||
| if (_singleSesson == null) | |||
| _singleSesson = new Session(); | |||
| } | |||
| } | |||
| return _singleSesson; | |||
| return defaultSession; | |||
| } | |||
| /// <summary> | |||
| @@ -93,15 +29,8 @@ namespace Tensorflow | |||
| /// <returns>The default `Session` being used in the current thread.</returns> | |||
| public static Session set_default_session(Session sess) | |||
| { | |||
| if (!isSingleThreaded) | |||
| return tf.defaultSession = sess; | |||
| lock (_threadingLock) | |||
| { | |||
| _singleSesson = sess; | |||
| } | |||
| return _singleSesson; | |||
| defaultSession = sess; | |||
| return sess; | |||
| } | |||
| /// <summary> | |||
| @@ -118,10 +47,18 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public static Graph get_default_graph() | |||
| => default_graph_stack.get_default(); | |||
| { | |||
| if (default_graph_stack == null) | |||
| default_graph_stack = new DefaultGraphStack(); | |||
| return default_graph_stack.get_default(); | |||
| } | |||
| public static Graph set_default_graph(Graph g) | |||
| => default_graph_stack.get_controller(g); | |||
| { | |||
| if (default_graph_stack == null) | |||
| default_graph_stack = new DefaultGraphStack(); | |||
| return default_graph_stack.get_controller(g); | |||
| } | |||
| /// <summary> | |||
| /// Clears the default graph stack and resets the global default graph. | |||
| @@ -135,6 +72,8 @@ namespace Tensorflow | |||
| /// <returns></returns> | |||
| public static void reset_default_graph() | |||
| { | |||
| if (default_graph_stack == null) | |||
| return; | |||
| //if (!_default_graph_stack.is_cleared()) | |||
| // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | |||
| // "nested graphs. If you need a cleared graph, " + | |||
| @@ -143,7 +82,11 @@ namespace Tensorflow | |||
| } | |||
| public static Graph peak_default_graph() | |||
| => default_graph_stack.peak_controller(); | |||
| { | |||
| if (default_graph_stack == null) | |||
| default_graph_stack = new DefaultGraphStack(); | |||
| return default_graph_stack.peak_controller(); | |||
| } | |||
| public static void pop_graph() | |||
| => default_graph_stack.pop(); | |||
| @@ -16,6 +16,7 @@ | |||
| using Serilog; | |||
| using Serilog.Core; | |||
| using System.Threading; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Gradients; | |||
| @@ -38,12 +39,18 @@ namespace Tensorflow | |||
| public TF_DataType chars = TF_DataType.TF_STRING; | |||
| public TF_DataType @string = TF_DataType.TF_STRING; | |||
| public Status Status; | |||
| public OpDefLibrary OpDefLib; | |||
| public Context Context; | |||
| public IEagerRunner Runner; | |||
| public Logger Logger; | |||
| ThreadLocal<Status> _status = new ThreadLocal<Status>(() => new Status()); | |||
| public Status Status => _status.Value; | |||
| ThreadLocal<Context> _context = new ThreadLocal<Context>(() => new Context()); | |||
| public Context Context => _context.Value; | |||
| ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner()); | |||
| public IEagerRunner Runner => _runner.Value; | |||
| public tensorflow() | |||
| { | |||
| Logger = new LoggerConfiguration() | |||
| @@ -51,12 +58,8 @@ namespace Tensorflow | |||
| .WriteTo.Console() | |||
| .CreateLogger(); | |||
| Status = new Status(); | |||
| Context = new Context(); | |||
| OpDefLib = new OpDefLibrary(); | |||
| ConstructThreadingObjects(); | |||
| InitGradientEnvironment(); | |||
| Runner = new EagerRunner(); | |||
| } | |||
| public string VERSION => c_api.StringPiece(c_api.TF_Version()); | |||
| @@ -1,53 +0,0 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.Runtime.CompilerServices; | |||
| using System.Threading; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class tensorflow | |||
| { | |||
| protected ThreadLocal<Session> defaultSessionFactory; | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public void ConstructThreadingObjects() | |||
| { | |||
| defaultSessionFactory = new ThreadLocal<Session>(() => new Session()); | |||
| } | |||
| public Session defaultSession | |||
| { | |||
| get | |||
| { | |||
| if (!ops.IsSingleThreaded) | |||
| return defaultSessionFactory.Value; | |||
| return ops.get_default_session(); | |||
| } | |||
| internal set | |||
| { | |||
| if (!ops.IsSingleThreaded) | |||
| { | |||
| defaultSessionFactory.Value = value; | |||
| return; | |||
| } | |||
| ops.set_default_session(value); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -12,6 +12,7 @@ using Tensorflow.Keras.Models; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Keras.Utils; | |||
| using System.Threading; | |||
| namespace Tensorflow.Keras | |||
| { | |||
| @@ -9,6 +9,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Functions; | |||
| using System.Threading; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -40,24 +41,24 @@ namespace Tensorflow.Keras.Layers | |||
| return MakOp(inputs); | |||
| } | |||
| ConcreteFunction function; | |||
| ThreadLocal<ConcreteFunction> function = new ThreadLocal<ConcreteFunction>(); | |||
| Tensors DeFunCall(Tensors inputs) | |||
| { | |||
| if(function == null) | |||
| if (function.Value == null) | |||
| { | |||
| function = new ConcreteFunction(name); | |||
| function.Enter(); | |||
| function.Value = new ConcreteFunction(name); | |||
| function.Value.Enter(); | |||
| int i = 0; | |||
| var graph_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: $"defun_inputs_{i++}")).ToArray(); | |||
| var graph_outputs = MakOp(graph_inputs); | |||
| graph_outputs = mark_as_return(graph_outputs); | |||
| function.ToGraph(graph_inputs, graph_outputs); | |||
| function.Exit(); | |||
| function.Value.ToGraph(graph_inputs, graph_outputs); | |||
| function.Value.Exit(); | |||
| } | |||
| var outputs = function.FilteredCall(inputs); | |||
| var outputs = function.Value.FilteredCall(inputs); | |||
| return outputs; | |||
| } | |||
| @@ -24,14 +24,12 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| Assert.IsNull(tf.peak_default_graph()); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var default_graph = tf.get_default_graph(); | |||
| var sess_graph = sess.graph; | |||
| Assert.IsNotNull(default_graph); | |||
| Assert.IsNotNull(sess_graph); | |||
| Assert.AreEqual(default_graph, sess_graph); | |||
| } | |||
| using var sess = tf.Session(); | |||
| var default_graph = tf.get_default_graph(); | |||
| var sess_graph = sess.graph; | |||
| Assert.IsNotNull(default_graph); | |||
| Assert.IsNotNull(sess_graph); | |||
| Assert.AreEqual(default_graph, sess_graph); | |||
| } | |||
| } | |||
| @@ -47,14 +45,12 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| Assert.IsNull(tf.peak_default_graph()); | |||
| //tf.Session created an other graph | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var default_graph = tf.get_default_graph(); | |||
| var sess_graph = sess.graph; | |||
| Assert.IsNotNull(default_graph); | |||
| Assert.IsNotNull(sess_graph); | |||
| Assert.AreEqual(default_graph, sess_graph); | |||
| } | |||
| using var sess = tf.Session(); | |||
| var default_graph = tf.get_default_graph(); | |||
| var sess_graph = sess.graph; | |||
| Assert.IsNotNull(default_graph); | |||
| Assert.IsNotNull(sess_graph); | |||
| Assert.AreEqual(default_graph, sess_graph); | |||
| } | |||
| } | |||
| @@ -73,20 +69,12 @@ namespace TensorFlowNET.UnitTest | |||
| beforehand.as_default(); | |||
| Assert.IsNotNull(tf.peak_default_graph()); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var default_graph = tf.peak_default_graph(); | |||
| var sess_graph = sess.graph; | |||
| Assert.IsNotNull(default_graph); | |||
| Assert.IsNotNull(sess_graph); | |||
| Assert.AreEqual(default_graph, sess_graph); | |||
| Console.WriteLine($"{tid}-{default_graph.graph_key}"); | |||
| //var result = sess.run(new object[] {g, a}); | |||
| //var actualDeriv = result[0].GetData<float>()[0]; | |||
| //var actual = result[1].GetData<float>()[0]; | |||
| } | |||
| using var sess = tf.Session(); | |||
| var default_graph = tf.peak_default_graph(); | |||
| var sess_graph = sess.graph; | |||
| Assert.IsNotNull(default_graph); | |||
| Assert.IsNotNull(sess_graph); | |||
| Assert.AreEqual(default_graph, sess_graph); | |||
| } | |||
| } | |||
| @@ -114,13 +102,10 @@ namespace TensorFlowNET.UnitTest | |||
| //the core method | |||
| void Core(int tid) | |||
| { | |||
| using (var sess = tf.Session()) | |||
| using var sess = tf.Session(); | |||
| for (int i = 0; i < 100; i++) | |||
| { | |||
| Tensor t = null; | |||
| for (int i = 0; i < 100; i++) | |||
| { | |||
| t = new Tensor(1); | |||
| } | |||
| var t = new Tensor(1); | |||
| } | |||
| } | |||
| } | |||
| @@ -134,12 +119,10 @@ namespace TensorFlowNET.UnitTest | |||
| void Core(int tid) | |||
| { | |||
| //tf.Session created an other graph | |||
| using (var sess = tf.Session()) | |||
| using var sess = tf.Session(); | |||
| for (int i = 0; i < 100; i++) | |||
| { | |||
| for (int i = 0; i < 100; i++) | |||
| { | |||
| var t = new Tensor(new int[] { 1, 2, 3 }); | |||
| } | |||
| var t = new Tensor(new int[] { 1, 2, 3 }); | |||
| } | |||
| } | |||
| } | |||
| @@ -147,23 +130,23 @@ namespace TensorFlowNET.UnitTest | |||
| [TestMethod] | |||
| public void SessionRun() | |||
| { | |||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||
| MultiThreadedUnitTestExecuter.Run(2, Core); | |||
| //the core method | |||
| void Core(int tid) | |||
| { | |||
| tf.compat.v1.disable_eager_execution(); | |||
| var graph = tf.Graph().as_default(); | |||
| //graph is created automatically to perform create these operations | |||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||
| var math = a1 + a2; | |||
| using var sess = tf.Session(graph); | |||
| for (int i = 0; i < 100; i++) | |||
| { | |||
| var graph = tf.get_default_graph(); | |||
| using (var sess = tf.Session(graph)) | |||
| { | |||
| var result = sess.run(math); | |||
| Assert.AreEqual(result[0], 5f); | |||
| } | |||
| var result = sess.run(math); | |||
| Assert.AreEqual(result[0], 5f); | |||
| } | |||
| } | |||
| } | |||
| @@ -176,17 +159,18 @@ namespace TensorFlowNET.UnitTest | |||
| //the core method | |||
| void Core(int tid) | |||
| { | |||
| using (var sess = tf.Session()) | |||
| { | |||
| Assert.IsNotNull(tf.get_default_graph()); | |||
| //graph is created automatically to perform create these operations | |||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||
| var math = a1 + a2; | |||
| tf.compat.v1.disable_eager_execution(); | |||
| var graph = tf.Graph().as_default(); | |||
| var result = sess.run(math); | |||
| Assert.AreEqual(result[0], 5f); | |||
| } | |||
| using var sess = tf.Session(graph); | |||
| Assert.IsNotNull(tf.get_default_graph()); | |||
| //graph is created automatically to perform create these operations | |||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||
| var math = a1 + a2; | |||
| var result = sess.run(math); | |||
| Assert.AreEqual(result[0], 5f); | |||
| } | |||
| } | |||
| @@ -198,14 +182,12 @@ namespace TensorFlowNET.UnitTest | |||
| //the core method | |||
| void Core(int tid) | |||
| { | |||
| using (var sess = tf.Session()) | |||
| { | |||
| Assert.IsNotNull(tf.get_default_graph()); | |||
| //graph is created automatically to perform create these operations | |||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||
| var math = a1 + a2; | |||
| } | |||
| using var sess = tf.Session(); | |||
| Assert.IsNotNull(tf.get_default_graph()); | |||
| //graph is created automatically to perform create these operations | |||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||
| var math = a1 + a2; | |||
| } | |||
| } | |||
| @@ -234,6 +216,10 @@ namespace TensorFlowNET.UnitTest | |||
| void Core(int tid) | |||
| { | |||
| Assert.IsNull(tf.peak_default_graph()); | |||
| tf.compat.v1.disable_eager_execution(); | |||
| var graph = tf.Graph().as_default(); | |||
| //graph is created automatically to perform create these operations | |||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }, name: "ConstantK"); | |||
| @@ -248,7 +234,6 @@ namespace TensorFlowNET.UnitTest | |||
| private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/"); | |||
| [Ignore] | |||
| [TestMethod] | |||
| public void TF_GraphOperationByName_FromModel() | |||
| { | |||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||
| @@ -0,0 +1,95 @@ | |||
| using System; | |||
| using Tensorflow.Keras.Engine; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| using System.Threading.Tasks; | |||
| using Tensorflow.NumPy; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| namespace TensorFlowNET.Keras.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class MultiThreads | |||
| { | |||
| [TestMethod] | |||
| public void Test1() | |||
| { | |||
| //Arrange | |||
| string savefile = "mymodel.h5"; | |||
| var model1 = BuildModel(); | |||
| model1.save_weights(savefile); | |||
| var model2 = BuildModel(); | |||
| //act | |||
| model1.load_weights(savefile); | |||
| model2.load_weights(savefile); | |||
| } | |||
| [TestMethod] | |||
| public void Test2() | |||
| { | |||
| //Arrange | |||
| string savefile = "mymodel2.h5"; | |||
| var model1 = BuildModel(); | |||
| model1.save_weights(savefile); | |||
| model1 = BuildModel(); //recreate model | |||
| //act | |||
| model1.load_weights(savefile); | |||
| } | |||
| [TestMethod] | |||
| public void Test3Multithreading() | |||
| { | |||
| //Arrange | |||
| string savefile = "mymodel3.h5"; | |||
| var model = BuildModel(); | |||
| model.save_weights(savefile); | |||
| //Sanity check without multithreading | |||
| for (int i = 0; i < 2; i++) | |||
| { | |||
| Functional clone = BuildModel(); | |||
| clone.load_weights(savefile); | |||
| //Predict something | |||
| clone.predict(np.array(new float[,] { { 0, 0 } })); | |||
| } //works | |||
| //act | |||
| ParallelOptions parallelOptions = new ParallelOptions(); | |||
| parallelOptions.MaxDegreeOfParallelism = 1; | |||
| var input = np.array(new float[,] { { 0, 0 } }); | |||
| Parallel.For(0, 1, parallelOptions, i => | |||
| { | |||
| var clone = BuildModel(); | |||
| clone.load_weights(savefile); | |||
| //Predict something | |||
| clone.predict(input); | |||
| }); | |||
| } | |||
| Functional BuildModel() | |||
| { | |||
| tf.Context.reset_context(); | |||
| var inputs = keras.Input(shape: 2); | |||
| // 1st dense layer | |||
| var DenseLayer = keras.layers.Dense(1, activation: keras.activations.Sigmoid); | |||
| var outputs = DenseLayer.Apply(inputs); | |||
| // build keras model | |||
| Functional model = keras.Model(inputs, outputs, name: Guid.NewGuid().ToString()); | |||
| // show model summary | |||
| model.summary(); | |||
| // compile keras model into tensorflow's static graph | |||
| model.compile(loss: keras.losses.MeanSquaredError(name: Guid.NewGuid().ToString()), | |||
| optimizer: keras.optimizers.Adam(name: Guid.NewGuid().ToString()), | |||
| metrics: new[] { "accuracy" }); | |||
| return model; | |||
| } | |||
| } | |||
| } | |||
| @@ -16,7 +16,6 @@ namespace TensorFlowNET.UnitTest | |||
| /// <summary>Initializes a new instance of the <see cref="T:System.Object" /> class.</summary> | |||
| public EnforcedSinglethreadingTests() | |||
| { | |||
| ops.IsSingleThreaded = true; | |||
| } | |||
| [TestMethod, Ignore("Has to be tested manually.")] | |||
| @@ -24,8 +23,6 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| lock (_singlethreadLocker) | |||
| { | |||
| ops.IsSingleThreaded.Should().BeTrue(); | |||
| ops.uid(); //increment id by one | |||
| //the core method | |||