| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Google.Protobuf; | |||||
| using System.IO; | using System.IO; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| @@ -37,7 +38,9 @@ namespace Tensorflow | |||||
| using (var buffer = ToGraphDef(status)) | using (var buffer = ToGraphDef(status)) | ||||
| { | { | ||||
| status.Check(true); | status.Check(true); | ||||
| def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| // limit size to 250M, recursion to max 100 | |||||
| var inputStream = CodedInputStream.CreateWithLimits(buffer.MemoryBlock.Stream(), 250 * 1024 * 1024, 100); | |||||
| def = GraphDef.Parser.ParseFrom(inputStream); | |||||
| } | } | ||||
| // Strip the experimental library field iff it's empty. | // Strip the experimental library field iff it's empty. | ||||
| @@ -15,6 +15,8 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.IO; | |||||
| using System.Runtime.CompilerServices; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -39,6 +41,7 @@ namespace Tensorflow | |||||
| return this; | return this; | ||||
| } | } | ||||
| [MethodImpl(MethodImplOptions.NoOptimization)] | |||||
| public static Session LoadFromSavedModel(string path) | public static Session LoadFromSavedModel(string path) | ||||
| { | { | ||||
| lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
| @@ -50,20 +53,36 @@ namespace Tensorflow | |||||
| var tags = new string[] {"serve"}; | var tags = new string[] {"serve"}; | ||||
| var buffer = new TF_Buffer(); | var buffer = new TF_Buffer(); | ||||
| var sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||||
| IntPtr.Zero, | |||||
| path, | |||||
| tags, | |||||
| tags.Length, | |||||
| graph, | |||||
| ref buffer, | |||||
| status); | |||||
| IntPtr sess; | |||||
| try | |||||
| { | |||||
| sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||||
| IntPtr.Zero, | |||||
| path, | |||||
| tags, | |||||
| tags.Length, | |||||
| graph, | |||||
| ref buffer, | |||||
| status); | |||||
| status.Check(true); | |||||
| } catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) | |||||
| { | |||||
| status = new Status(); | |||||
| sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||||
| IntPtr.Zero, | |||||
| Path.GetFullPath(path), | |||||
| tags, | |||||
| tags.Length, | |||||
| graph, | |||||
| ref buffer, | |||||
| status); | |||||
| status.Check(true); | |||||
| } | |||||
| // load graph bytes | // load graph bytes | ||||
| // var data = new byte[buffer.length]; | // var data = new byte[buffer.length]; | ||||
| // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | ||||
| // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | ||||
| status.Check(true); | |||||
| return new Session(sess, g: new Graph(graph)).as_default(); | return new Session(sess, g: new Graph(graph)).as_default(); | ||||
| } | } | ||||
| @@ -1,8 +1,11 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using System.Linq; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using FluentAssertions; | using FluentAssertions; | ||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using NumSharp; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -260,7 +263,7 @@ namespace TensorFlowNET.UnitTest | |||||
| } | } | ||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| public void TF_GraphOperationByName() | public void TF_GraphOperationByName() | ||||
| { | { | ||||
| @@ -280,5 +283,46 @@ namespace TensorFlowNET.UnitTest | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| private static string modelPath = "./model/"; | |||||
| [TestMethod] | |||||
| public void TF_GraphOperationByName_FromModel() | |||||
| { | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
| //the core method | |||||
| void Core(int tid) | |||||
| { | |||||
| Console.WriteLine(); | |||||
| for (int j = 0; j < 100; j++) | |||||
| { | |||||
| var sess = Session.LoadFromSavedModel(modelPath).as_default(); | |||||
| var inputs = new[] {"sp", "fuel"}; | |||||
| var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray(); | |||||
| var outp = sess.graph.OperationByName("softmax_tensor").output; | |||||
| for (var i = 0; i < 100; i++) | |||||
| { | |||||
| { | |||||
| var data = new float[96]; | |||||
| FeedItem[] feeds = new FeedItem[2]; | |||||
| for (int f = 0; f < 2; f++) | |||||
| feeds[f] = new FeedItem(inp[f], new NDArray(data)); | |||||
| try | |||||
| { | |||||
| sess.run(outp, feeds); | |||||
| } catch (Exception ex) | |||||
| { | |||||
| Console.WriteLine(ex); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -42,4 +42,10 @@ | |||||
| <ProjectReference Include="..\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj" /> | <ProjectReference Include="..\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | |||||
| <None Update="model\saved_model.pb"> | |||||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
| </None> | |||||
| </ItemGroup> | |||||
| </Project> | </Project> | ||||