diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs index 0ac0becf..42cf1a17 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -8,6 +8,7 @@ namespace Tensorflow { public partial class Graph { + // Current control flow context. It could be either CondContext or WhileContext public IControlFlowContext _control_flow_context; // represents the nested with(...) statements diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index 0385341a..a0d84e89 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -64,6 +64,9 @@ namespace Tensorflow.Operations } } + /// + /// Add the subgraph defined by fn() to the graph. + /// public (T, Tensor) BuildCondBranch(Func fn) { // Add the subgraph defined by fn() to the graph. @@ -71,6 +74,22 @@ namespace Tensorflow.Operations var original_result = fn(); var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + //TODO: port this chunck of missing code: + /* + if len(post_summaries) > len(pre_summaries): + new_summaries = post_summaries[len(pre_summaries):] + summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access + summary_ref[:] = pre_summaries + with ops.control_dependencies(new_summaries): + if original_result is None: + return no_op(), None + else: + original_result = nest.map_structure(array_ops.identity, + original_result) + */ + if (original_result == null) + return (original_result, null); + switch (original_result) { case Operation[] results: diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 8776f171..9d039d58 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -3,7 +3,24 @@ using System.Collections.Generic; using System.Text; namespace Tensorflow.Operations -{ +{ + /// + /// The base class for control flow context. + /// + /// The usage pattern is a sequence of(Enter, Exit) followed by a final + /// ExitResult. + /// + /// We maintain the following state for control flow contexts during graph + /// construction: + /// 1. graph has _control_flow_context: the current context used to + /// construct new nodes.Changed by ctxt.Enter() and ctxt.Exit() + /// 2. op has _control_flow_context: the context to which the op belongs. + /// Set at the time the op is created.Immutable. + /// 3. A ControlFlowContext has _outer_context: the context in which this + /// context is created.Set at the time a context is created.Immutable. + /// 4. A ControlFlowContext has _context_stack. + /// Pushed and popped by ctxt.Enter() and ctxt.Exit() + /// public abstract class ControlFlowContext : IPython, IControlFlowContext { /// @@ -17,6 +34,8 @@ namespace Tensorflow.Operations _context_stack = new Stack(); } + public string name { get; set; } + public void __init__() { @@ -26,6 +45,13 @@ namespace Tensorflow.Operations { } + public void __exit__() + { + } + + /// + /// Enter this control flow context. + /// public virtual void Enter() { var graph = ops.get_default_graph(); @@ -33,6 +59,16 @@ namespace Tensorflow.Operations graph._set_control_flow_context(this); } + /// + /// Exit this control flow context. + /// + public virtual void Exit() + { + var graph = ops.get_default_graph(); + var last_context = _context_stack.Pop(); + graph._set_control_flow_context(last_context); + } + public void AddOp(Operation op) { _AddOpInternal(op); @@ -56,17 +92,6 @@ namespace Tensorflow.Operations var internal_control_inputs = op.control_inputs; } - public void Exit() - { - var graph = ops.get_default_graph(); - var last_context = _context_stack.Pop(); - graph._set_control_flow_context(last_context); - } - - public void __exit__() - { - } - public void Dispose() { } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 0b7afced..5c27d24d 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -187,6 +187,48 @@ namespace Tensorflow return @switch(data, pred, name: name); } + /// + /// Return `true_fn()` if the predicate `pred` is true else `false_fn()`. + /// + /// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and + /// `false_fn` must have the same non-zero number and type of outputs. + /// + /// **WARNING**: Any Tensors or Operations created outside of `true_fn` and + /// `false_fn` will be executed regardless of which branch is selected at runtime. + /// + /// Although this behavior is consistent with the dataflow model of TensorFlow, + /// it has frequently surprised users who expected a lazier semantics. + /// Consider the following simple program: + /// + /// z = tf.multiply(a, b) + /// result = tf.cond(x < y, ()=> tf.add(x, z), ()=> tf.square(y)) + /// + /// If `x<y`, the `tf.add` operation will be executed and `tf.square` + /// operation will not be executed.Since `z` is needed for at least one + /// branch of the `cond`, the `tf.multiply` operation is always executed, + /// unconditionally. + /// + /// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the + /// call to `cond`, and not at all during `Session.run()`). `cond` + /// stitches together the graph fragments created during the `true_fn` and + /// `false_fn` calls with some additional graph nodes to ensure that the right + /// branch gets executed depending on the value of `pred`. + /// + /// `tf.cond` supports nested structures as implemented in + /// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the + /// same(possibly nested) value structure of lists, tuples, and/or named tuples. + /// Singleton lists and tuples form the only exceptions to this: when returned by + /// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. + /// This behavior is disabled by passing `strict= True`. + /// + /// A scalar determining whether to return the result of `true_fn` or + /// `false_fn`. + /// The callable to be performed if pred is true. + /// The callable to be performed if pred is false. + /// A boolean that enables/disables 'strict' mode; see above. + /// Optional name prefix for the returned tensors. + /// Tensors returned by the call to either `true_fn` or `false_fn`. If the + /// callables return a singleton list, the element is extracted from the list. public static Tensor cond(Tensor pred, Func true_fn = null, Func false_fn = null, @@ -195,6 +237,37 @@ namespace Tensorflow { return with(ops.name_scope(name, "cond", new { pred }), delegate { + // TODO: here a chunk of original code is missing + /* + if fn1 is not None: + if true_fn is not None: + raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.") + true_fn = fn1 + elif true_fn is None: + raise TypeError("cond(): true_fn argument required") + if fn2 is not None: + if false_fn is not None: + raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.") + false_fn = fn2 + elif false_fn is None: + raise TypeError("cond(): false_fn argument required") + + if not callable(true_fn): + raise TypeError("true_fn must be callable.") + if not callable(false_fn): + raise TypeError("false_fn must be callable.") + + with ops.name_scope(name, "cond", [pred]): + if context.executing_eagerly(): + if pred: + return _UnpackIfSingleton(true_fn()) + return _UnpackIfSingleton(false_fn()) + + # Add the Switch to the graph. + if isinstance(pred, bool): + raise TypeError("pred must not be a Python bool") + */ + // Add the Switch to the graph. var (p_2, p_1) = @switch(pred, pred); var pivot_1 = array_ops.identity(p_1, name: "switch_t"); @@ -207,15 +280,45 @@ namespace Tensorflow // Build the graph for the true branch in a new context. var context_t = new CondContext(pred, pivot_1, branch: 1); - context_t.Enter(); - var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); - context_t.Exit(); - + ITensorOrOperation orig_res_t; + Tensor res_t; + try + { + context_t.Enter(); + (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); + } + finally + { + context_t.Exit(); + } // Build the graph for the false branch in a new context. var context_f = new CondContext(pred, pivot_2, branch: 0); - context_f.Enter(); - var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); - context_f.Exit(); + ITensorOrOperation orig_res_f; + Tensor res_f; + try + { + context_f.Enter(); + (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); + } + finally + { + context_f.Exit(); + } + + //TODO: missing original code + //if not strict: + // orig_res_t = _UnpackIfSingleton(orig_res_t) + // orig_res_f = _UnpackIfSingleton(orig_res_f) + /* + # Check that the return values of the two branches have the same structure. + try: + nest.assert_same_structure(orig_res_t, orig_res_f) + except TypeError as e: + raise TypeError( + "Incompatible return types of true_fn and false_fn: {}".format(e)) + except ValueError as e: + raise ValueError( + "Incompatible return values of true_fn and false_fn: {}".format(e)) var res_t_flat = new Tensor[] { res_t }; var res_f_flat = new Tensor[] { res_f }; diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index ce859ec8..eaa57681 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -1,5 +1,6 @@ using NumSharp; using System; +using System.Collections; using System.Collections.Generic; using System.ComponentModel; using System.Linq; @@ -17,8 +18,8 @@ namespace Tensorflow Console.WriteLine(obj.ToString()); } - protected int len(Array a) - => a.Length; + protected int len(IEnumerable a) + => a.Count(); protected IEnumerable range(int end) { diff --git a/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs index 88e4da9d..78c9cd45 100644 --- a/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs +++ b/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs @@ -61,115 +61,150 @@ namespace TensorFlowNET.UnitTest self.assertEqual(op4.name, "myop_1_1"); }); } + + [Ignore("Something is not right, Switch gets not inserted correctly?")] + [TestMethod] + public void TestCond() + { + var graph = tf.Graph().as_default(); + with(graph, g => + { + var x = constant_op.constant(10); + + var true_fn = new Func(() => + { + var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); + var new_ops = g._add_new_tf_operations(); + self.assertEqual(len(new_ops), 1); + return x; + }); + + control_flow_ops.cond(x < 10, true_fn, () => x); + + var op = g.get_operation_by_name("cond/myop"); + self.assertIsNotNone(op); + self.assertEqual(op.name, "cond/myop"); + self.assertEqual(op.type, "Identity"); + //self.assertEqual(op.outputs, new object[0]); + var op_input = op.inputs[0].op; + self.assertEqual(op_input.type, "Switch"); + self.assertEqual(op_input.inputs[0], x); + self.assertEqual(op.graph, g); + self.assertIsNotNone(op._get_control_flow_context()); + // TODO: op._get_control_flow_context().name not implemented + //self.assertEqual(op._get_control_flow_context().name, "cond/cond_text"); + }); + /* + @test_util.run_v1_only("b/120545219") + def testCond(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + + def true_fn(): + ops._create_c_op(ops.get_default_graph(), + ops._NodeDef("IntInput", "cond/myop"), [x], []) + new_ops = g._add_new_tf_operations() + self.assertEqual(len(new_ops), 1) + return x + + control_flow_ops.cond(x < 10, true_fn, lambda: x) + + op = g.get_operation_by_name("cond/myop") + self.assertIsNotNone(op) + self.assertEqual(op.name, "cond/myop") + self.assertEqual(op.type, "IntInput") + self.assertEqual(op.outputs, []) + op_input = op.inputs[0].op + self.assertEqual(op_input.type, "Switch") + self.assertEqual(op_input.inputs[0], x) + self.assertEqual(op.graph, g) + # pylint: disable=protected-access + self.assertIsNotNone(op._get_control_flow_context()) + self.assertEqual(op._get_control_flow_context().name, + "cond/cond_text") + # pylint: enable=protected-access + */ + } /* - @test_util.run_v1_only("b/120545219") - def testCond(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - - def true_fn(): - ops._create_c_op(ops.get_default_graph(), - ops._NodeDef("IntInput", "cond/myop"), [x], []) - new_ops = g._add_new_tf_operations() - self.assertEqual(len(new_ops), 1) - return x - - control_flow_ops.cond(x < 10, true_fn, lambda: x) - - op = g.get_operation_by_name("cond/myop") - self.assertIsNotNone(op) - self.assertEqual(op.name, "cond/myop") - self.assertEqual(op.type, "IntInput") - self.assertEqual(op.outputs, []) - op_input = op.inputs[0].op - self.assertEqual(op_input.type, "Switch") - self.assertEqual(op_input.inputs[0], x) - self.assertEqual(op.graph, g) - # pylint: disable=protected-access - self.assertIsNotNone(op._get_control_flow_context()) - self.assertEqual(op._get_control_flow_context().name, - "cond/cond_text") - # pylint: enable=protected-access - - @test_util.run_v1_only("b/120545219") - def testWhileLoop(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - - def body(i): - ops._create_c_op(ops.get_default_graph(), - ops._NodeDef("IntInput", "myloop/myop"), [x], []) - new_ops = g._add_new_tf_operations() - self.assertEqual(len(new_ops), 1) - return i - - control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") - - op = g.get_operation_by_name("myloop/myop") - self.assertIsNotNone(op) - self.assertEqual(op.name, "myloop/myop") - self.assertEqual(op.type, "IntInput") - self.assertEqual(op.outputs, []) - op_input = op.inputs[0].op - self.assertEqual(op_input.type, "Enter") - self.assertEqual(list(op_input.inputs), [x]) - self.assertEqual(op.graph, g) - # pylint: disable=protected-access - self.assertIsNotNone(op._get_control_flow_context()) - self.assertEqual(op._get_control_flow_context().name, - "myloop/while_context") - # pylint: enable=protected-access - - @test_util.run_v1_only("b/120545219") - def testWhileLoopWithInternalControlDep(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - - def body(i): - c = constant_op.constant(1.0, name="c") - ops._create_c_op(ops.get_default_graph(), - ops._NodeDef("IntInput", "myloop/myop"), [x], []) - with ops.control_dependencies([c]): - new_ops = g._add_new_tf_operations() - self.assertEqual(len(new_ops), 1) - return i - - control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") - - op = g.get_operation_by_name("myloop/myop") - self.assertIsNotNone(op) - c = g.get_operation_by_name("myloop/c") - self.assertIsNotNone(c) - # Internal control dep is preserved - self.assertEqual(op.control_inputs, [c]) - - @test_util.run_v1_only("b/120545219") - def testWhileLoopWithExternalControlDep(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - c = constant_op.constant(1.0) - - def body(i): - ops._create_c_op(ops.get_default_graph(), - ops._NodeDef("IntInput", "myloop/myop"), [x], []) - with ops.control_dependencies([c]): - new_ops = g._add_new_tf_operations() - self.assertEqual(len(new_ops), 1) - return i - - control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") - - op = g.get_operation_by_name("myloop/myop") - self.assertIsNotNone(op) - # External control dep is removed and replaced with internal control dep - self.assertNotEqual(op.control_inputs[0], c.op) - self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) - - - */ + @test_util.run_v1_only("b/120545219") + def testWhileLoop(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + + def body(i): + ops._create_c_op(ops.get_default_graph(), + ops._NodeDef("IntInput", "myloop/myop"), [x], []) + new_ops = g._add_new_tf_operations() + self.assertEqual(len(new_ops), 1) + return i + + control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") + + op = g.get_operation_by_name("myloop/myop") + self.assertIsNotNone(op) + self.assertEqual(op.name, "myloop/myop") + self.assertEqual(op.type, "IntInput") + self.assertEqual(op.outputs, []) + op_input = op.inputs[0].op + self.assertEqual(op_input.type, "Enter") + self.assertEqual(list(op_input.inputs), [x]) + self.assertEqual(op.graph, g) + # pylint: disable=protected-access + self.assertIsNotNone(op._get_control_flow_context()) + self.assertEqual(op._get_control_flow_context().name, + "myloop/while_context") + # pylint: enable=protected-access + + @test_util.run_v1_only("b/120545219") + def testWhileLoopWithInternalControlDep(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + + def body(i): + c = constant_op.constant(1.0, name="c") + ops._create_c_op(ops.get_default_graph(), + ops._NodeDef("IntInput", "myloop/myop"), [x], []) + with ops.control_dependencies([c]): + new_ops = g._add_new_tf_operations() + self.assertEqual(len(new_ops), 1) + return i + + control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") + + op = g.get_operation_by_name("myloop/myop") + self.assertIsNotNone(op) + c = g.get_operation_by_name("myloop/c") + self.assertIsNotNone(c) + # Internal control dep is preserved + self.assertEqual(op.control_inputs, [c]) + + @test_util.run_v1_only("b/120545219") + def testWhileLoopWithExternalControlDep(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + c = constant_op.constant(1.0) + + def body(i): + ops._create_c_op(ops.get_default_graph(), + ops._NodeDef("IntInput", "myloop/myop"), [x], []) + with ops.control_dependencies([c]): + new_ops = g._add_new_tf_operations() + self.assertEqual(len(new_ops), 1) + return i + + control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") + + op = g.get_operation_by_name("myloop/myop") + self.assertIsNotNone(op) + # External control dep is removed and replaced with internal control dep + self.assertNotEqual(op.control_inputs[0], c.op) + self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) + + + */ } } diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index e52c21ff..7dbf5e23 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -29,6 +29,11 @@ namespace TensorFlowNET.UnitTest Assert.AreEqual(expected, given); } + public void assertIsNotNone(object given) + { + Assert.IsNotNull(given); + } + protected PythonTest self { get => this; } } }