Browse Source

Merge pull request #217 from henon/master

Changes to Cond and Control Flow Context
tags/v0.9
Haiping GitHub 6 years ago
parent
commit
e7d7b2f6b6
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 337 additions and 140 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  2. +25
    -1
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  3. +37
    -12
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  4. +122
    -16
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  5. +3
    -2
      src/TensorFlowNET.Core/Python.cs
  6. +144
    -109
      test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs
  7. +5
    -0
      test/TensorFlowNET.UnitTest/PythonTest.cs

+ 1
- 0
src/TensorFlowNET.Core/Graphs/Graph.Control.cs View File

@@ -8,6 +8,7 @@ namespace Tensorflow
{ {
public partial class Graph public partial class Graph
{ {
// Current control flow context. It could be either CondContext or WhileContext
public IControlFlowContext _control_flow_context; public IControlFlowContext _control_flow_context;


// represents the nested with(...) statements // represents the nested with(...) statements


+ 25
- 1
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -64,6 +64,9 @@ namespace Tensorflow.Operations
} }
} }


/// <summary>
/// Add the subgraph defined by fn() to the graph.
/// </summary>
public (T, Tensor) BuildCondBranch<T>(Func<T> fn) public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
{ {
// Add the subgraph defined by fn() to the graph. // Add the subgraph defined by fn() to the graph.
@@ -71,10 +74,31 @@ namespace Tensorflow.Operations
var original_result = fn(); var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);


//TODO: port this chunck of missing code:
/*
if len(post_summaries) > len(pre_summaries):
new_summaries = post_summaries[len(pre_summaries):]
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
summary_ref[:] = pre_summaries
with ops.control_dependencies(new_summaries):
if original_result is None:
return no_op(), None
else:
original_result = nest.map_structure(array_ops.identity,
original_result)
*/
if (original_result == null)
return (original_result, null);

switch (original_result) switch (original_result)
{ {
case Operation[] results:
case Operation[] results:
// Python code:
// result = nest.map_structure(self._BuildCondTensor, original_result)
return (original_result, _BuildCondTensor(results)); return (original_result, _BuildCondTensor(results));
case Tensor t:
// TODO: should this be (original_result, t) instead?
return (original_result, _BuildCondTensor(new []{t.op}));
case float[] fv: case float[] fv:
var result = ops.convert_to_tensor(fv[0]); var result = ops.convert_to_tensor(fv[0]);
return (original_result, result ); return (original_result, result );


+ 37
- 12
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -3,7 +3,24 @@ using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Operations namespace Tensorflow.Operations
{
{
/// <summary>
/// The base class for control flow context.
///
/// The usage pattern is a sequence of(Enter, Exit) followed by a final
/// ExitResult.
///
/// We maintain the following state for control flow contexts during graph
/// construction:
/// 1. graph has _control_flow_context: the current context used to
/// construct new nodes.Changed by ctxt.Enter() and ctxt.Exit()
/// 2. op has _control_flow_context: the context to which the op belongs.
/// Set at the time the op is created.Immutable.
/// 3. A ControlFlowContext has _outer_context: the context in which this
/// context is created.Set at the time a context is created.Immutable.
/// 4. A ControlFlowContext has _context_stack.
/// Pushed and popped by ctxt.Enter() and ctxt.Exit()
/// </summary>
public abstract class ControlFlowContext : IPython, IControlFlowContext public abstract class ControlFlowContext : IPython, IControlFlowContext
{ {
/// <summary> /// <summary>
@@ -17,6 +34,8 @@ namespace Tensorflow.Operations
_context_stack = new Stack<IControlFlowContext>(); _context_stack = new Stack<IControlFlowContext>();
} }


public string name { get; set; }

public void __init__() public void __init__()
{ {


@@ -26,6 +45,13 @@ namespace Tensorflow.Operations
{ {
} }


public void __exit__()
{
}

/// <summary>
/// Enter this control flow context.
/// </summary>
public virtual void Enter() public virtual void Enter()
{ {
var graph = ops.get_default_graph(); var graph = ops.get_default_graph();
@@ -33,6 +59,16 @@ namespace Tensorflow.Operations
graph._set_control_flow_context(this); graph._set_control_flow_context(this);
} }


/// <summary>
/// Exit this control flow context.
/// </summary>
public virtual void Exit()
{
var graph = ops.get_default_graph();
var last_context = _context_stack.Pop();
graph._set_control_flow_context(last_context);
}

public void AddOp(Operation op) public void AddOp(Operation op)
{ {
_AddOpInternal(op); _AddOpInternal(op);
@@ -56,17 +92,6 @@ namespace Tensorflow.Operations
var internal_control_inputs = op.control_inputs; var internal_control_inputs = op.control_inputs;
} }


public void Exit()
{
var graph = ops.get_default_graph();
var last_context = _context_stack.Pop();
graph._set_control_flow_context(last_context);
}

public void __exit__()
{
}

public void Dispose() public void Dispose()
{ {
} }


+ 122
- 16
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -185,8 +185,50 @@ namespace Tensorflow
ops.colocate_with(data, ignore_existing: true); ops.colocate_with(data, ignore_existing: true);


return @switch(data, pred, name: name); return @switch(data, pred, name: name);
}

}
/// <summary>
/// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
///
/// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
/// `false_fn` must have the same non-zero number and type of outputs.
///
/// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
/// `false_fn` will be executed regardless of which branch is selected at runtime.
///
/// Although this behavior is consistent with the dataflow model of TensorFlow,
/// it has frequently surprised users who expected a lazier semantics.
/// Consider the following simple program:
///
/// z = tf.multiply(a, b)
/// result = tf.cond(x &lt; y, ()=> tf.add(x, z), ()=> tf.square(y))
///
/// If `x&lt;y`, the `tf.add` operation will be executed and `tf.square`
/// operation will not be executed.Since `z` is needed for at least one
/// branch of the `cond`, the `tf.multiply` operation is always executed,
/// unconditionally.
///
/// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
/// call to `cond`, and not at all during `Session.run()`). `cond`
/// stitches together the graph fragments created during the `true_fn` and
/// `false_fn` calls with some additional graph nodes to ensure that the right
/// branch gets executed depending on the value of `pred`.
///
/// `tf.cond` supports nested structures as implemented in
/// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
/// same(possibly nested) value structure of lists, tuples, and/or named tuples.
/// Singleton lists and tuples form the only exceptions to this: when returned by
/// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
/// This behavior is disabled by passing `strict= True`.
/// </summary>
/// <param name="pred"> A scalar determining whether to return the result of `true_fn` or
/// `false_fn`.</param>
/// <param name="true_fn">The callable to be performed if pred is true.</param>
/// <param name="false_fn">The callable to be performed if pred is false.</param>
/// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.</param>
/// <param name="name">Optional name prefix for the returned tensors.</param>
/// <returns>Tensors returned by the call to either `true_fn` or `false_fn`. If the
/// callables return a singleton list, the element is extracted from the list.</returns>
public static Tensor cond(Tensor pred, public static Tensor cond(Tensor pred,
Func<ITensorOrOperation> true_fn = null, Func<ITensorOrOperation> true_fn = null,
Func<ITensorOrOperation> false_fn = null, Func<ITensorOrOperation> false_fn = null,
@@ -195,6 +237,37 @@ namespace Tensorflow
{ {
return with(ops.name_scope(name, "cond", new { pred }), delegate return with(ops.name_scope(name, "cond", new { pred }), delegate
{ {
// TODO: here a chunk of original code is missing
/*
if fn1 is not None:
if true_fn is not None:
raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
true_fn = fn1
elif true_fn is None:
raise TypeError("cond(): true_fn argument required")
if fn2 is not None:
if false_fn is not None:
raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
false_fn = fn2
elif false_fn is None:
raise TypeError("cond(): false_fn argument required")
if not callable(true_fn):
raise TypeError("true_fn must be callable.")
if not callable(false_fn):
raise TypeError("false_fn must be callable.")

with ops.name_scope(name, "cond", [pred]):
if context.executing_eagerly():
if pred:
return _UnpackIfSingleton(true_fn())
return _UnpackIfSingleton(false_fn())
# Add the Switch to the graph.
if isinstance(pred, bool):
raise TypeError("pred must not be a Python bool")
*/

// Add the Switch to the graph. // Add the Switch to the graph.
var (p_2, p_1) = @switch(pred, pred); var (p_2, p_1) = @switch(pred, pred);
var pivot_1 = array_ops.identity(p_1, name: "switch_t"); var pivot_1 = array_ops.identity(p_1, name: "switch_t");
@@ -207,30 +280,63 @@ namespace Tensorflow


// Build the graph for the true branch in a new context. // Build the graph for the true branch in a new context.
var context_t = new CondContext(pred, pivot_1, branch: 1); var context_t = new CondContext(pred, pivot_1, branch: 1);
context_t.Enter();
var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
context_t.Exit();

ITensorOrOperation orig_res_t;
Tensor res_t;
try
{
context_t.Enter();
(orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
}
finally
{
context_t.Exit();
}
// Build the graph for the false branch in a new context. // Build the graph for the false branch in a new context.
var context_f = new CondContext(pred, pivot_2, branch: 0); var context_f = new CondContext(pred, pivot_2, branch: 0);
context_f.Enter();
var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
context_f.Exit();
ITensorOrOperation orig_res_f;
Tensor res_f;
try
{
context_f.Enter();
(orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
}
finally
{
context_f.Exit();
}


var res_t_flat = res_t;
var res_f_flat = res_f;
//TODO: missing original code
//if not strict:
// orig_res_t = _UnpackIfSingleton(orig_res_t)
// orig_res_f = _UnpackIfSingleton(orig_res_f)
/*
# Check that the return values of the two branches have the same structure.
try:
nest.assert_same_structure(orig_res_t, orig_res_f)
except TypeError as e:
raise TypeError(
"Incompatible return types of true_fn and false_fn: {}".format(e))
except ValueError as e:
raise ValueError(
"Incompatible return values of true_fn and false_fn: {}".format(e))
# Add the final merge to the graph.
if not res_t:
raise ValueError("true_fn and false_fn must return at least one result.
*/
var res_t_flat = new[] { res_t };
var res_f_flat = new[] { res_f };


return new Tensor(IntPtr.Zero);
/*var merges = zip(res_f_flat, res_t_flat)
var merges = zip(res_f_flat, res_t_flat)
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
.ToArray(); .ToArray();


merges = _convert_flows_to_tensorarrays(orig_res_t, merges);

merges = _convert_flows_to_tensorarrays(new [] { orig_res_t}, merges);
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f);


return merges;*/
return merges[0];
}); });
} }




+ 3
- 2
src/TensorFlowNET.Core/Python.cs View File

@@ -1,5 +1,6 @@
using NumSharp; using NumSharp;
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel; using System.ComponentModel;
using System.Linq; using System.Linq;
@@ -17,8 +18,8 @@ namespace Tensorflow
Console.WriteLine(obj.ToString()); Console.WriteLine(obj.ToString());
} }


protected int len(Array a)
=> a.Length;
protected int len<T>(IEnumerable<T> a)
=> a.Count();


protected IEnumerable<int> range(int end) protected IEnumerable<int> range(int end)
{ {


+ 144
- 109
test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs View File

@@ -61,115 +61,150 @@ namespace TensorFlowNET.UnitTest
self.assertEqual(op4.name, "myop_1_1"); self.assertEqual(op4.name, "myop_1_1");
}); });
} }
[Ignore("Something is not right, Switch gets not inserted correctly?")]
[TestMethod]
public void TestCond()
{
var graph = tf.Graph().as_default();
with<Graph>(graph, g =>
{
var x = constant_op.constant(10);
var true_fn = new Func<Tensor>(() =>
{
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]);
var new_ops = g._add_new_tf_operations();
self.assertEqual(len(new_ops), 1);
return x;
});
control_flow_ops.cond(x < 10, true_fn, () => x);
var op = g.get_operation_by_name("cond/myop");
self.assertIsNotNone(op);
self.assertEqual(op.name, "cond/myop");
self.assertEqual(op.type, "Identity");
//self.assertEqual(op.outputs, new object[0]);
var op_input = op.inputs[0].op;
self.assertEqual(op_input.type, "Switch");
self.assertEqual(op_input.inputs[0], x);
self.assertEqual(op.graph, g);
self.assertIsNotNone(op._get_control_flow_context());
// TODO: op._get_control_flow_context().name not implemented
//self.assertEqual(op._get_control_flow_context().name, "cond/cond_text");
});
/*
@test_util.run_v1_only("b/120545219")
def testCond(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def true_fn():
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "cond/myop"), [x], [])
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return x
control_flow_ops.cond(x < 10, true_fn, lambda: x)
op = g.get_operation_by_name("cond/myop")
self.assertIsNotNone(op)
self.assertEqual(op.name, "cond/myop")
self.assertEqual(op.type, "IntInput")
self.assertEqual(op.outputs, [])
op_input = op.inputs[0].op
self.assertEqual(op_input.type, "Switch")
self.assertEqual(op_input.inputs[0], x)
self.assertEqual(op.graph, g)
# pylint: disable=protected-access
self.assertIsNotNone(op._get_control_flow_context())
self.assertEqual(op._get_control_flow_context().name,
"cond/cond_text")
# pylint: enable=protected-access
*/
}
/* /*
@test_util.run_v1_only("b/120545219")
def testCond(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def true_fn():
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "cond/myop"), [x], [])
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return x
control_flow_ops.cond(x < 10, true_fn, lambda: x)
op = g.get_operation_by_name("cond/myop")
self.assertIsNotNone(op)
self.assertEqual(op.name, "cond/myop")
self.assertEqual(op.type, "IntInput")
self.assertEqual(op.outputs, [])
op_input = op.inputs[0].op
self.assertEqual(op_input.type, "Switch")
self.assertEqual(op_input.inputs[0], x)
self.assertEqual(op.graph, g)
# pylint: disable=protected-access
self.assertIsNotNone(op._get_control_flow_context())
self.assertEqual(op._get_control_flow_context().name,
"cond/cond_text")
# pylint: enable=protected-access
@test_util.run_v1_only("b/120545219")
def testWhileLoop(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def body(i):
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
self.assertEqual(op.name, "myloop/myop")
self.assertEqual(op.type, "IntInput")
self.assertEqual(op.outputs, [])
op_input = op.inputs[0].op
self.assertEqual(op_input.type, "Enter")
self.assertEqual(list(op_input.inputs), [x])
self.assertEqual(op.graph, g)
# pylint: disable=protected-access
self.assertIsNotNone(op._get_control_flow_context())
self.assertEqual(op._get_control_flow_context().name,
"myloop/while_context")
# pylint: enable=protected-access
@test_util.run_v1_only("b/120545219")
def testWhileLoopWithInternalControlDep(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def body(i):
c = constant_op.constant(1.0, name="c")
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
with ops.control_dependencies([c]):
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
c = g.get_operation_by_name("myloop/c")
self.assertIsNotNone(c)
# Internal control dep is preserved
self.assertEqual(op.control_inputs, [c])
@test_util.run_v1_only("b/120545219")
def testWhileLoopWithExternalControlDep(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
c = constant_op.constant(1.0)
def body(i):
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
with ops.control_dependencies([c]):
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
# External control dep is removed and replaced with internal control dep
self.assertNotEqual(op.control_inputs[0], c.op)
self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
*/
@test_util.run_v1_only("b/120545219")
def testWhileLoop(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def body(i):
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
self.assertEqual(op.name, "myloop/myop")
self.assertEqual(op.type, "IntInput")
self.assertEqual(op.outputs, [])
op_input = op.inputs[0].op
self.assertEqual(op_input.type, "Enter")
self.assertEqual(list(op_input.inputs), [x])
self.assertEqual(op.graph, g)
# pylint: disable=protected-access
self.assertIsNotNone(op._get_control_flow_context())
self.assertEqual(op._get_control_flow_context().name,
"myloop/while_context")
# pylint: enable=protected-access
@test_util.run_v1_only("b/120545219")
def testWhileLoopWithInternalControlDep(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def body(i):
c = constant_op.constant(1.0, name="c")
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
with ops.control_dependencies([c]):
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
c = g.get_operation_by_name("myloop/c")
self.assertIsNotNone(c)
# Internal control dep is preserved
self.assertEqual(op.control_inputs, [c])
@test_util.run_v1_only("b/120545219")
def testWhileLoopWithExternalControlDep(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
c = constant_op.constant(1.0)
def body(i):
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
with ops.control_dependencies([c]):
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
# External control dep is removed and replaced with internal control dep
self.assertNotEqual(op.control_inputs[0], c.op)
self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
*/
} }
} }

+ 5
- 0
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -29,6 +29,11 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(expected, given); Assert.AreEqual(expected, given);
} }
public void assertIsNotNone(object given)
{
Assert.IsNotNull(given);
}
protected PythonTest self { get => this; } protected PythonTest self { get => this; }
} }
} }

Loading…
Cancel
Save