diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs index 4a7e0ed8..2a0d939e 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using Google.Protobuf; using System.IO; using Tensorflow.Util; @@ -37,7 +38,9 @@ namespace Tensorflow using (var buffer = ToGraphDef(status)) { status.Check(true); - def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + // limit size to 250M, recursion to max 100 + var inputStream = CodedInputStream.CreateWithLimits(buffer.MemoryBlock.Stream(), 250 * 1024 * 1024, 100); + def = GraphDef.Parser.ParseFrom(inputStream); } // Strip the experimental library field iff it's empty. diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index 690b4e0d..a89d94dc 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -15,6 +15,8 @@ ******************************************************************************/ using System; +using System.IO; +using System.Runtime.CompilerServices; using Tensorflow.Util; using static Tensorflow.Binding; @@ -39,6 +41,7 @@ namespace Tensorflow return this; } + [MethodImpl(MethodImplOptions.NoOptimization)] public static Session LoadFromSavedModel(string path) { lock (Locks.ProcessWide) @@ -50,20 +53,36 @@ namespace Tensorflow var tags = new string[] {"serve"}; var buffer = new TF_Buffer(); - var sess = c_api.TF_LoadSessionFromSavedModel(opt, - IntPtr.Zero, - path, - tags, - tags.Length, - graph, - ref buffer, - status); + IntPtr sess; + try + { + sess = c_api.TF_LoadSessionFromSavedModel(opt, + IntPtr.Zero, + path, + tags, + tags.Length, + graph, + ref buffer, + status); + status.Check(true); + } catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) + { + status = new Status(); + sess = c_api.TF_LoadSessionFromSavedModel(opt, + IntPtr.Zero, + Path.GetFullPath(path), + tags, + tags.Length, + graph, + ref buffer, + status); + status.Check(true); + } // load graph bytes // var data = new byte[buffer.length]; // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ - status.Check(true); return new Session(sess, g: new Graph(graph)).as_default(); } diff --git a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs index b8b88c9c..1e4d829c 100644 --- a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs +++ b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs @@ -1,8 +1,11 @@ using System; using System.Collections.Generic; +using System.IO; +using System.Linq; using System.Runtime.InteropServices; using FluentAssertions; using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; using Tensorflow; using Tensorflow.Util; using static Tensorflow.Binding; @@ -260,7 +263,7 @@ namespace TensorFlowNET.UnitTest } } - + [TestMethod] public void TF_GraphOperationByName() { @@ -280,5 +283,46 @@ namespace TensorFlowNET.UnitTest } } } + + private static string modelPath = "./model/"; + + [TestMethod] + public void TF_GraphOperationByName_FromModel() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + Console.WriteLine(); + for (int j = 0; j < 100; j++) + { + var sess = Session.LoadFromSavedModel(modelPath).as_default(); + var inputs = new[] {"sp", "fuel"}; + + var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray(); + var outp = sess.graph.OperationByName("softmax_tensor").output; + + for (var i = 0; i < 100; i++) + { + { + var data = new float[96]; + FeedItem[] feeds = new FeedItem[2]; + + for (int f = 0; f < 2; f++) + feeds[f] = new FeedItem(inp[f], new NDArray(data)); + + try + { + sess.run(outp, feeds); + } catch (Exception ex) + { + Console.WriteLine(ex); + } + } + } + } + } + } } } \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index 58ab60cf..d2ea6ebf 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -42,4 +42,10 @@ + + + PreserveNewest + + +