diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs
index 4a7e0ed8..2a0d939e 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using Google.Protobuf;
using System.IO;
using Tensorflow.Util;
@@ -37,7 +38,9 @@ namespace Tensorflow
using (var buffer = ToGraphDef(status))
{
status.Check(true);
- def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
+ // limit size to 250M, recursion to max 100
+ var inputStream = CodedInputStream.CreateWithLimits(buffer.MemoryBlock.Stream(), 250 * 1024 * 1024, 100);
+ def = GraphDef.Parser.ParseFrom(inputStream);
}
// Strip the experimental library field iff it's empty.
diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs
index 690b4e0d..a89d94dc 100644
--- a/src/TensorFlowNET.Core/Sessions/Session.cs
+++ b/src/TensorFlowNET.Core/Sessions/Session.cs
@@ -15,6 +15,8 @@
******************************************************************************/
using System;
+using System.IO;
+using System.Runtime.CompilerServices;
using Tensorflow.Util;
using static Tensorflow.Binding;
@@ -39,6 +41,7 @@ namespace Tensorflow
return this;
}
+ [MethodImpl(MethodImplOptions.NoOptimization)]
public static Session LoadFromSavedModel(string path)
{
lock (Locks.ProcessWide)
@@ -50,20 +53,36 @@ namespace Tensorflow
var tags = new string[] {"serve"};
var buffer = new TF_Buffer();
- var sess = c_api.TF_LoadSessionFromSavedModel(opt,
- IntPtr.Zero,
- path,
- tags,
- tags.Length,
- graph,
- ref buffer,
- status);
+ IntPtr sess;
+ try
+ {
+ sess = c_api.TF_LoadSessionFromSavedModel(opt,
+ IntPtr.Zero,
+ path,
+ tags,
+ tags.Length,
+ graph,
+ ref buffer,
+ status);
+ status.Check(true);
+ } catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel"))
+ {
+ status = new Status();
+ sess = c_api.TF_LoadSessionFromSavedModel(opt,
+ IntPtr.Zero,
+ Path.GetFullPath(path),
+ tags,
+ tags.Length,
+ graph,
+ ref buffer,
+ status);
+ status.Check(true);
+ }
// load graph bytes
// var data = new byte[buffer.length];
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/
- status.Check(true);
return new Session(sess, g: new Graph(graph)).as_default();
}
diff --git a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs
index b8b88c9c..1e4d829c 100644
--- a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs
+++ b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs
@@ -1,8 +1,11 @@
using System;
using System.Collections.Generic;
+using System.IO;
+using System.Linq;
using System.Runtime.InteropServices;
using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
using Tensorflow;
using Tensorflow.Util;
using static Tensorflow.Binding;
@@ -260,7 +263,7 @@ namespace TensorFlowNET.UnitTest
}
}
-
+
[TestMethod]
public void TF_GraphOperationByName()
{
@@ -280,5 +283,46 @@ namespace TensorFlowNET.UnitTest
}
}
}
+
+ private static string modelPath = "./model/";
+
+ [TestMethod]
+ public void TF_GraphOperationByName_FromModel()
+ {
+ MultiThreadedUnitTestExecuter.Run(8, Core);
+
+ //the core method
+ void Core(int tid)
+ {
+ Console.WriteLine();
+ for (int j = 0; j < 100; j++)
+ {
+ var sess = Session.LoadFromSavedModel(modelPath).as_default();
+ var inputs = new[] {"sp", "fuel"};
+
+ var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray();
+ var outp = sess.graph.OperationByName("softmax_tensor").output;
+
+ for (var i = 0; i < 100; i++)
+ {
+ {
+ var data = new float[96];
+ FeedItem[] feeds = new FeedItem[2];
+
+ for (int f = 0; f < 2; f++)
+ feeds[f] = new FeedItem(inp[f], new NDArray(data));
+
+ try
+ {
+ sess.run(outp, feeds);
+ } catch (Exception ex)
+ {
+ Console.WriteLine(ex);
+ }
+ }
+ }
+ }
+ }
+ }
}
}
\ No newline at end of file
diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
index 58ab60cf..d2ea6ebf 100644
--- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
+++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
@@ -42,4 +42,10 @@
+
+
+ PreserveNewest
+
+
+