Browse Source

fix default graph and operation issue when import model.

tags/v0.12
Oceania2018 6 years ago
parent
commit
6623162244
9 changed files with 89 additions and 44 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Buffers/Buffer.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Framework/c_api_util.cs
  3. +25
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  4. +40
    -31
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +4
    -0
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  6. +4
    -1
      src/TensorFlowNET.Core/Operations/Operation.Implicit.cs
  7. +2
    -0
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  8. +10
    -8
      src/TensorFlowNET.Core/Sessions/Session.cs
  9. +1
    -1
      test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs

+ 2
- 2
src/TensorFlowNET.Core/Buffers/Buffer.cs View File

@@ -30,8 +30,8 @@ namespace Tensorflow
get get
{ {
var data = new byte[buffer.length]; var data = new byte[buffer.length];
if (buffer.length > 0)
Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
if (data.Length > 0)
Marshal.Copy(buffer.data, data, 0, data.Length);
return data; return data;
} }
} }


+ 1
- 1
src/TensorFlowNET.Core/Framework/c_api_util.cs View File

@@ -128,7 +128,7 @@ namespace Tensorflow
IntPtr c_op; IntPtr c_op;
while ((c_op = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) while ((c_op = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
{ {
yield return c_op;
yield return new Operation(c_op, graph);
} }
} }
} }


+ 25
- 0
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -38,6 +38,31 @@ namespace Tensorflow
return c_api.TF_NewOperation(_handle, opType, opName); return c_api.TF_NewOperation(_handle, opType, opName);
} }
public unsafe Operation[] ReturnOperations(IntPtr results)
{
TF_Operation return_oper_handle = new TF_Operation();
int num_return_opers = 0;
c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle);
Operation[] return_opers = new Operation[num_return_opers];
for (int i = 0; i < num_return_opers; i++)
{
var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i;
return_opers[i] = new Operation(*(IntPtr*)handle);
}
return return_opers;
}

public Operation OperationByName(string operName)
{
return c_api.TF_GraphOperationByName(_handle, operName);
}

public ITensorOrOperation[] get_operations()
{
return _nodes_by_name.Values.Select(x => x).ToArray();
}
/// <summary> /// <summary>
/// Returns the `Operation` with the given `name`. /// Returns the `Operation` with the given `name`.
/// ///


+ 40
- 31
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/ ******************************************************************************/


using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
@@ -72,7 +73,7 @@ namespace Tensorflow
all variables that are created during the construction of a graph. The caller all variables that are created during the construction of a graph. The caller
may define additional collections by specifying a new name. may define additional collections by specifying a new name.
*/ */
public partial class Graph : IPython, IDisposable
public partial class Graph : IPython, IDisposable, IEnumerable<Operation>
{ {
private IntPtr _handle; private IntPtr _handle;
private Dictionary<int, ITensorOrOperation> _nodes_by_id; private Dictionary<int, ITensorOrOperation> _nodes_by_id;
@@ -121,6 +122,10 @@ namespace Tensorflow
_nodes_by_name = new Dictionary<string, ITensorOrOperation>(); _nodes_by_name = new Dictionary<string, ITensorOrOperation>();
_names_in_use = new Dictionary<string, int>(); _names_in_use = new Dictionary<string, int>();
_graph_key = $"grap-key-{ops.uid()}/"; _graph_key = $"grap-key-{ops.uid()}/";
}
public void __enter__()
{
} }


public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
@@ -409,31 +414,6 @@ namespace Tensorflow
return return_outputs; return return_outputs;
} }


public unsafe Operation[] ReturnOperations(IntPtr results)
{
TF_Operation return_oper_handle = new TF_Operation();
int num_return_opers = 0;
c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle);
Operation[] return_opers = new Operation[num_return_opers];
for (int i = 0; i < num_return_opers; i++)
{
var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i;
return_opers[i] = new Operation(*(IntPtr*)handle);
}
return return_opers;
}

public Operation OperationByName(string operName)
{
return c_api.TF_GraphOperationByName(_handle, operName);
}

public ITensorOrOperation[] get_operations()
{
return _nodes_by_name.Values.Select(x => x).ToArray();
}

public string[] get_all_collection_keys() public string[] get_all_collection_keys()
{ {
return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray();
@@ -481,17 +461,46 @@ namespace Tensorflow
public Tensor get_tensor_by_name(string name) public Tensor get_tensor_by_name(string name)
{ {
return (Tensor)this.as_graph_element(name, allow_tensor: true, allow_operation: false); return (Tensor)this.as_graph_element(name, allow_tensor: true, allow_operation: false);
}

public void __enter__()
{
}
public TensorShape GetTensorShape(TF_Output output)
{
var status = new Status();
var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status);
status.Check();
if (ndim == -1)
return new TensorShape();
var dims = new long[ndim];
c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status);
status.Check();
return new TensorShape(dims.Select(x => (int)x).ToArray());
}
public override string ToString()
{
int len = 0;
return c_api.TF_GraphDebugString(_handle, out len);
} }


public void __exit__() public void __exit__()
{ {
}
}
private IEnumerable<Operation> GetEnumerable()
=> c_api_util.tf_operations(this);


IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator()
=> GetEnumerable().GetEnumerator();
IEnumerator IEnumerable.GetEnumerator()
{
throw new NotImplementedException();
}
public static implicit operator IntPtr(Graph graph) public static implicit operator IntPtr(Graph graph)
{ {
return graph._handle; return graph._handle;


+ 4
- 0
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -43,6 +43,9 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_DeleteImportGraphDefResults(IntPtr results); public static extern void TF_DeleteImportGraphDefResults(IntPtr results);


[DllImport(TensorFlowLibName)]
public static extern string TF_GraphDebugString(IntPtr graph, out int len);

[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status);


@@ -100,6 +103,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status);
/// <summary> /// <summary>
/// Iterate through the operations of a graph. /// Iterate through the operations of a graph.
/// </summary> /// </summary>


+ 4
- 1
src/TensorFlowNET.Core/Operations/Operation.Implicit.cs View File

@@ -23,7 +23,10 @@ namespace Tensorflow
/// </summary> /// </summary>
public partial class Operation public partial class Operation
{ {
public static implicit operator Operation(IntPtr handle) => new Operation(handle);
// make sure the new op is in the same graph instance
public static implicit operator Operation(IntPtr handle)
=> new Operation(handle);

public static implicit operator IntPtr(Operation op) => op._handle; public static implicit operator IntPtr(Operation op) => op._handle;
public static implicit operator Tensor(Operation op) => op.output; public static implicit operator Tensor(Operation op) => op.output;




+ 2
- 0
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -35,6 +35,8 @@ namespace Tensorflow


public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));


public TF_Output this[int index] => _tf_output(index);

public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
{ {
int size = Marshal.SizeOf<TF_Input>(); int size = Marshal.SizeOf<TF_Input>();


+ 10
- 8
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/ ******************************************************************************/


using System; using System;
using System.Runtime.InteropServices;


namespace Tensorflow namespace Tensorflow
{ {
@@ -26,8 +27,8 @@ namespace Tensorflow


} }


public Session(IntPtr handle)
: base("", null, null)
public Session(IntPtr handle, Graph g = null)
: base("", g, null)
{ {
_session = handle; _session = handle;
} }
@@ -50,8 +51,10 @@ namespace Tensorflow
var graph = c_api.TF_NewGraph(); var graph = c_api.TF_NewGraph();
var status = new Status(); var status = new Status();
var opt = c_api.TF_NewSessionOptions(); var opt = c_api.TF_NewSessionOptions();

var tags = new string[] { "serve" }; var tags = new string[] { "serve" };
var buffer = new TF_Buffer(); var buffer = new TF_Buffer();

var sess = c_api.TF_LoadSessionFromSavedModel(opt, var sess = c_api.TF_LoadSessionFromSavedModel(opt,
IntPtr.Zero, IntPtr.Zero,
path, path,
@@ -61,14 +64,13 @@ namespace Tensorflow
ref buffer, ref buffer,
status); status);


//var bytes = new Buffer(buffer.data).Data;
//var meta_graph = MetaGraphDef.Parser.ParseFrom(bytes);

// load graph bytes
// var data = new byte[buffer.length];
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/
status.Check(); status.Check();


new Graph(graph).as_default();

return sess;
return new Session(sess, g: new Graph(graph).as_default());
} }


public static implicit operator IntPtr(Session session) => session._session; public static implicit operator IntPtr(Session session) => session._session;


+ 1
- 1
test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs View File

@@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples
float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels)); float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels));
print($"Accuracy: {acc.ToString("F4")}"); print($"Accuracy: {acc.ToString("F4")}");


return acc > 0.88;
return acc > 0.9;
}); });
} }




Loading…
Cancel
Save