Browse Source

Session.LoadFromSavedModel: Added fallback for non-absolute path

tags/v0.12
Eli Belash 6 years ago
parent
commit
da70cdc790
1 changed files with 26 additions and 9 deletions
  1. +26
    -9
      src/TensorFlowNET.Core/Sessions/Session.cs

+ 26
- 9
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/ ******************************************************************************/


using System; using System;
using System.IO;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -50,20 +51,36 @@ namespace Tensorflow
var tags = new string[] {"serve"}; var tags = new string[] {"serve"};
var buffer = new TF_Buffer(); var buffer = new TF_Buffer();


var sess = c_api.TF_LoadSessionFromSavedModel(opt,
IntPtr.Zero,
path,
tags,
tags.Length,
graph,
ref buffer,
status);
IntPtr sess;
try
{
sess = c_api.TF_LoadSessionFromSavedModel(opt,
IntPtr.Zero,
path,
tags,
tags.Length,
graph,
ref buffer,
status);
status.Check(true);
} catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel"))
{
status = new Status();
sess = c_api.TF_LoadSessionFromSavedModel(opt,
IntPtr.Zero,
Path.GetFullPath(path),
tags,
tags.Length,
graph,
ref buffer,
status);
status.Check(true);
}


// load graph bytes // load graph bytes
// var data = new byte[buffer.length]; // var data = new byte[buffer.length];
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length); // Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/
status.Check(true);


return new Session(sess, g: new Graph(graph)).as_default(); return new Session(sess, g: new Graph(graph)).as_default();
} }


Loading…
Cancel
Save