Browse Source

Session: LoadFromSavedModel Fixed multithreading issue and added as_default() to the created session, fixed #380

tags/v0.12
Eli Belash 6 years ago
parent
commit
b2512954bf
1 changed files with 27 additions and 23 deletions
  1. +27
    -23
      src/TensorFlowNET.Core/Sessions/Session.cs

+ 27
- 23
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/

using System;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -40,29 +41,32 @@ namespace Tensorflow

public static Session LoadFromSavedModel(string path)
{
var graph = c_api.TF_NewGraph();
var status = new Status();
var opt = new SessionOptions();

var tags = new string[] { "serve" };
var buffer = new TF_Buffer();

var sess = c_api.TF_LoadSessionFromSavedModel(opt,
IntPtr.Zero,
path,
tags,
tags.Length,
graph,
ref buffer,
status);

// load graph bytes
// var data = new byte[buffer.length];
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/
status.Check(true);

return new Session(sess, g: new Graph(graph).as_default());
lock (Locks.ProcessWide)
{
var graph = c_api.TF_NewGraph();
var status = new Status();
var opt = new SessionOptions();

var tags = new string[] {"serve"};
var buffer = new TF_Buffer();

var sess = c_api.TF_LoadSessionFromSavedModel(opt,
IntPtr.Zero,
path,
tags,
tags.Length,
graph,
ref buffer,
status);

// load graph bytes
// var data = new byte[buffer.length];
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/
status.Check(true);

return new Session(sess, g: new Graph(graph).as_default()).as_default();
}
}

public static implicit operator IntPtr(Session session) => session._handle;


Loading…
Cancel
Save