Browse Source

tf.train.import_meta_graph can import CondContext.

tags/v0.9
Oceania2018 6 years ago
parent
commit
f4067f28f7
12 changed files with 1300 additions and 32 deletions
  1. +18
    -3
      src/TensorFlowNET.Core/Framework/meta_graph.py.cs
  2. +5
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  3. +40
    -8
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  4. +32
    -2
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  6. +4
    -4
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  7. +1172
    -0
      src/TensorFlowNET.Core/Protobuf/ControlFlow.cs
  8. +4
    -4
      src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs
  9. +5
    -2
      src/TensorFlowNET.Core/Protobuf/README.md
  10. +2
    -2
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  11. +1
    -1
      src/TensorFlowNET.Core/ops.py.cs
  12. +16
    -5
      test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs

+ 18
- 3
src/TensorFlowNET.Core/Framework/meta_graph.py.cs View File

@@ -4,6 +4,7 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow.Operations;
using static Tensorflow.CollectionDef;
using static Tensorflow.MetaGraphDef.Types;

@@ -95,15 +96,29 @@ namespace Tensorflow
}
else
{
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
foreach(var value in col.Value.BytesList.Value)
{
switch (col.Key)
{
case "cond_context":
var proto = CondContextDef.Parser.ParseFrom(value);
var condContext = new CondContext().from_proto(proto, import_scope);
graph.add_to_collection(col.Key, condContext);
break;
default:
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
}
}
}
break;
default:
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
}
}

var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
scope: scope_to_prepend_to_names) as List<RefVariable>;
var variables = graph.get_collection<RefVariable>(ops.GraphKeys.GLOBAL_VARIABLES,
scope: scope_to_prepend_to_names);
var var_list = new Dictionary<string, RefVariable>();
variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v);



+ 5
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -412,6 +412,11 @@ namespace Tensorflow
return _collections.ContainsKey(name) ? _collections[name] : null;
}

public List<T> get_collection<T>(string name, string scope = null)
{
return _collections.ContainsKey(name) ? _collections[name] as List<T> : new List<T>();
}

public object get_collection_ref(string name)
{
if (!_collections.ContainsKey(name))


+ 40
- 8
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow.Operations
/// <summary>
/// The context for the conditional construct.
/// </summary>
public class CondContext : ControlFlowContext
public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext>
{


@@ -35,16 +35,20 @@ namespace Tensorflow.Operations
/// <param name="name">Name of the `CondContext` python object.</param>
/// <param name="context_def"></param>
/// <param name="import_scope"></param>
public CondContext(Tensor pred,
Tensor pivot,
int branch,
public CondContext(Tensor pred = null,
Tensor pivot = null,
int? branch = null,
string name = "cond_text",
object context_def = null,
CondContextDef context_def = null,
string import_scope = null)
{
if (pred == null && context_def == null) return;

_name = ops.get_default_graph().unique_name(name);
if (context_def != null)
throw new NotImplementedException("CondContext context_def is not null");
if (context_def != null)
{
_init_from_proto(context_def, import_scope: import_scope);
}
else
{
// Initializes the default fields.
@@ -61,6 +65,18 @@ namespace Tensorflow.Operations
}
}

private void _init_from_proto(CondContextDef context_def, string import_scope = null)
{
var g = ops.get_default_graph();
_name = ops.prepend_name_scope(context_def.ContextName, import_scope);
var p1 = ops.prepend_name_scope(context_def.PredName, import_scope);
_pred = g.as_graph_element(p1) as Tensor;
var p2 = ops.prepend_name_scope(context_def.PivotName, import_scope);
_pivot = g.as_graph_element(p2) as Tensor;
_branch = context_def.Branch;
__init__(values_def: context_def.ValuesDef, import_scope: import_scope);
}

/// <summary>
/// Add `val` to the current context and its outer context recursively.
/// </summary>
@@ -230,6 +246,22 @@ namespace Tensorflow.Operations
public override void AddInnerOp(Operation resultOp)
{
throw new NotImplementedException();
}
}
public CondContextDef to_proto(string export_scope)
{
throw new NotImplementedException();
}
public CondContext from_proto(CondContextDef proto, string import_scope)
{
var ret = new CondContext(context_def: proto, import_scope: import_scope);
ret.Enter();
foreach (var nested_def in proto.NestedContexts)
throw new NotImplementedException("");
ret.Exit();
return ret;
}
}
}

+ 32
- 2
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -32,6 +32,8 @@ namespace Tensorflow.Operations
protected Stack<IControlFlowContext> _context_stack;
protected IControlFlowContext _outer_context;

protected Dictionary<string, ITensorOrOperation> _external_values;

public ControlFlowContext()
{
_context_stack = new Stack<IControlFlowContext>();
@@ -40,15 +42,43 @@ namespace Tensorflow.Operations
public string name { get => _name; }
protected string _name;

public void __init__()
public void __init__(ValuesDef values_def = null, string import_scope = null)
{

_outer_context = ops.get_default_graph()._get_control_flow_context();
if (values_def != null)
_init_values_from_proto(values_def, import_scope: import_scope);
}

public void __enter__()
{
}

/// <summary>
/// Initializes values and external_values from `ValuesDef` protocol buffer.
/// </summary>
/// <param name="values_def"></param>
/// <param name="import_scope"></param>
protected void _init_values_from_proto(ValuesDef values_def, string import_scope = null)
{
_external_values = new Dictionary<string, ITensorOrOperation>();
foreach (var value in values_def.Values)
_values.Add(value);
var g = ops.get_default_graph();
foreach(var value in values_def.ExternalValues)
{
var k = ops.prepend_name_scope(value.Key, import_scope);
var v = value.Value;
_external_values[k] = g.as_graph_element(ops.prepend_name_scope(v, import_scope));
}
var op_names = _values.Where(x => !_external_values.ContainsKey(x))
.Select(x => x.Split(':')[0])
.ToArray();
foreach (var op in op_names)
(g.as_graph_element(op) as Operation)._set_control_flow_context(this);
}

public void __exit__()
{
}


+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -287,7 +287,7 @@ namespace Tensorflow
// Reset cached inputs.
_inputs = null;// new InputList(new Tensor[] { tensor }); // is this right? original code: self._inputs_val=None
// TODO: implement below code dependencies
// c_api.TF_UpdateEdge(graph, output, input, status);
c_api.TF_UpdateEdge(graph, output, input, status);
}

private void _assert_same_graph(Tensor tensor)


+ 4
- 4
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -330,7 +330,7 @@ namespace Tensorflow
tensor.op.graph.prevent_fetching(tensor.op);

// Build the graph for the true branch in a new context.
var context_t = new CondContext(pred, pivot_1, branch: 1);
var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
ITensorOrOperation orig_res_t;
Tensor res_t;
try
@@ -343,7 +343,7 @@ namespace Tensorflow
context_t.Exit();
}
// Build the graph for the false branch in a new context.
var context_f = new CondContext(pred, pivot_2, branch: 0);
var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
ITensorOrOperation orig_res_f;
Tensor res_f;
try
@@ -411,13 +411,13 @@ namespace Tensorflow
tensor.op.graph.prevent_fetching(tensor.op);

// Build the graph for the true branch in a new context.
var context_t = new CondContext(pred, pivot_1, branch: 1);
var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
context_t.Enter();
var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
context_t.Exit();

// Build the graph for the false branch in a new context.
var context_f = new CondContext(pred, pivot_2, branch: 0);
var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
context_f.Enter();
var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
context_f.Exit();


+ 1172
- 0
src/TensorFlowNET.Core/Protobuf/ControlFlow.cs
File diff suppressed because it is too large
View File


+ 4
- 4
src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow
/// In order for a object to be serialized to and from MetaGraphDef,
/// the class must implement to_proto() and from_proto() methods
/// </summary>
public interface IProtoBuf
public interface IProtoBuf<TProtoDef, TDef>
{
string name { get; }

@@ -17,15 +17,15 @@ namespace Tensorflow
/// </summary>
/// <param name="export_scope"></param>
/// <returns></returns>
VariableDef to_proto(string export_scope);
TProtoDef to_proto(string export_scope);

/// <summary>
/// Returns a `Variable` object created from `variable_def`.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="variable_def"></param>
/// <param name="proto"></param>
/// <param name="import_scope"></param>
/// <returns></returns>
T from_proto<T>(VariableDef variable_def, string import_scope);
TDef from_proto(TProtoDef proto, string import_scope);
}
}

+ 5
- 2
src/TensorFlowNET.Core/Protobuf/README.md View File

@@ -1,10 +1,12 @@
### Download compiler from https://github.com/protocolbuffers/protobuf/releases
Work in command line

```shell
cd tensorflow

set SRC_DIR=D:/Projects/tensorflow
set DST_DIR=D:/Projects/TensorFlow.NET/src/TensorFlowNET.Core/Protobuf

cd tensorflow

protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/resource_handle.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_shape.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/types.proto
@@ -32,6 +34,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.prot
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/config.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/debug.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/rewriter_config.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/control_flow.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto
```


+ 2
- 2
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -7,7 +7,7 @@ using System.Text;

namespace Tensorflow
{
public partial class RefVariable : VariableV1, IProtoBuf
public partial class RefVariable : VariableV1, IProtoBuf<VariableDef, RefVariable>
{
public bool _in_graph_mode = true;
public Tensor _initial_value;
@@ -288,7 +288,7 @@ namespace Tensorflow
throw new NotImplementedException("to_proto RefVariable");
}

public T from_proto<T>(VariableDef variable_def, string import_scope)
public RefVariable from_proto(VariableDef proto, string import_scope)
{
throw new NotImplementedException();
}


+ 1
- 1
src/TensorFlowNET.Core/ops.py.cs View File

@@ -376,7 +376,7 @@ namespace Tensorflow
if (import_scope.EndsWith("/"))
import_scope = import_scope.Substring(0, import_scope.Length - 1);

throw new NotImplementedException("prepend_name_scope");
return $"{import_scope}/{name}";
}
else
return name;


+ 16
- 5
test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs View File

@@ -10,17 +10,28 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
[TestClass]
public class CondTestCases : PythonTest
{
[TestMethod]
public void testCondTrue()
{
with(tf.Session(), sess =>
var graph = tf.Graph().as_default();
with(tf.Session(graph), sess =>
{
var x = tf.constant(2);
var y = tf.constant(5);
var z = control_flow_ops.cond(tf.less(x, y),
() => tf.multiply(x, 17),
() => tf.add(y, 23));
var pred = tf.less(x, y);
Func<ITensorOrOperation> if_true = delegate
{
return tf.multiply(x, 17);
};
Func<ITensorOrOperation> if_false = delegate
{
return tf.add(y, 23);
};
var z = control_flow_ops.cond(pred, if_true, if_false);
int result = z.eval(sess);
assertEquals(result, 34);
});


Loading…
Cancel
Save