| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Google.Protobuf; | |||
| using System.IO; | |||
| using Tensorflow.Util; | |||
| @@ -37,7 +38,9 @@ namespace Tensorflow | |||
| using (var buffer = ToGraphDef(status)) | |||
| { | |||
| status.Check(true); | |||
| def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
| // limit size to 250M, recursion to max 100 | |||
| var inputStream = CodedInputStream.CreateWithLimits(buffer.MemoryBlock.Stream(), 250 * 1024 * 1024, 100); | |||
| def = GraphDef.Parser.ParseFrom(inputStream); | |||
| } | |||
| // Strip the experimental library field iff it's empty. | |||
| @@ -15,6 +15,8 @@ | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.IO; | |||
| using System.Runtime.CompilerServices; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| @@ -39,6 +41,7 @@ namespace Tensorflow | |||
| return this; | |||
| } | |||
| [MethodImpl(MethodImplOptions.NoOptimization)] | |||
| public static Session LoadFromSavedModel(string path) | |||
| { | |||
| lock (Locks.ProcessWide) | |||
| @@ -50,20 +53,36 @@ namespace Tensorflow | |||
| var tags = new string[] {"serve"}; | |||
| var buffer = new TF_Buffer(); | |||
| var sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
| IntPtr.Zero, | |||
| path, | |||
| tags, | |||
| tags.Length, | |||
| graph, | |||
| ref buffer, | |||
| status); | |||
| IntPtr sess; | |||
| try | |||
| { | |||
| sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
| IntPtr.Zero, | |||
| path, | |||
| tags, | |||
| tags.Length, | |||
| graph, | |||
| ref buffer, | |||
| status); | |||
| status.Check(true); | |||
| } catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) | |||
| { | |||
| status = new Status(); | |||
| sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
| IntPtr.Zero, | |||
| Path.GetFullPath(path), | |||
| tags, | |||
| tags.Length, | |||
| graph, | |||
| ref buffer, | |||
| status); | |||
| status.Check(true); | |||
| } | |||
| // load graph bytes | |||
| // var data = new byte[buffer.length]; | |||
| // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||
| // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||
| status.Check(true); | |||
| return new Session(sess, g: new Graph(graph)).as_default(); | |||
| } | |||
| @@ -1,8 +1,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| using FluentAssertions; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using NumSharp; | |||
| using Tensorflow; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| @@ -260,7 +263,7 @@ namespace TensorFlowNET.UnitTest | |||
| } | |||
| } | |||
| [TestMethod] | |||
| public void TF_GraphOperationByName() | |||
| { | |||
| @@ -280,5 +283,46 @@ namespace TensorFlowNET.UnitTest | |||
| } | |||
| } | |||
| } | |||
| private static string modelPath = "./model/"; | |||
| [TestMethod] | |||
| public void TF_GraphOperationByName_FromModel() | |||
| { | |||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||
| //the core method | |||
| void Core(int tid) | |||
| { | |||
| Console.WriteLine(); | |||
| for (int j = 0; j < 100; j++) | |||
| { | |||
| var sess = Session.LoadFromSavedModel(modelPath).as_default(); | |||
| var inputs = new[] {"sp", "fuel"}; | |||
| var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray(); | |||
| var outp = sess.graph.OperationByName("softmax_tensor").output; | |||
| for (var i = 0; i < 100; i++) | |||
| { | |||
| { | |||
| var data = new float[96]; | |||
| FeedItem[] feeds = new FeedItem[2]; | |||
| for (int f = 0; f < 2; f++) | |||
| feeds[f] = new FeedItem(inp[f], new NDArray(data)); | |||
| try | |||
| { | |||
| sess.run(outp, feeds); | |||
| } catch (Exception ex) | |||
| { | |||
| Console.WriteLine(ex); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -42,4 +42,10 @@ | |||
| <ProjectReference Include="..\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <None Update="model\saved_model.pb"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| </ItemGroup> | |||
| </Project> | |||