| @@ -8,6 +8,7 @@ namespace Tensorflow | |||
| { | |||
| public partial class Graph | |||
| { | |||
| // Current control flow context. It could be either CondContext or WhileContext | |||
| public IControlFlowContext _control_flow_context; | |||
| // represents the nested with(...) statements | |||
| @@ -64,6 +64,9 @@ namespace Tensorflow.Operations | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Add the subgraph defined by fn() to the graph. | |||
| /// </summary> | |||
| public (T, Tensor) BuildCondBranch<T>(Func<T> fn) | |||
| { | |||
| // Add the subgraph defined by fn() to the graph. | |||
| @@ -71,6 +74,22 @@ namespace Tensorflow.Operations | |||
| var original_result = fn(); | |||
| var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
| //TODO: port this chunck of missing code: | |||
| /* | |||
| if len(post_summaries) > len(pre_summaries): | |||
| new_summaries = post_summaries[len(pre_summaries):] | |||
| summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access | |||
| summary_ref[:] = pre_summaries | |||
| with ops.control_dependencies(new_summaries): | |||
| if original_result is None: | |||
| return no_op(), None | |||
| else: | |||
| original_result = nest.map_structure(array_ops.identity, | |||
| original_result) | |||
| */ | |||
| if (original_result == null) | |||
| return (original_result, null); | |||
| switch (original_result) | |||
| { | |||
| case Operation[] results: | |||
| @@ -3,7 +3,24 @@ using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| { | |||
| /// <summary> | |||
| /// The base class for control flow context. | |||
| /// | |||
| /// The usage pattern is a sequence of(Enter, Exit) followed by a final | |||
| /// ExitResult. | |||
| /// | |||
| /// We maintain the following state for control flow contexts during graph | |||
| /// construction: | |||
| /// 1. graph has _control_flow_context: the current context used to | |||
| /// construct new nodes.Changed by ctxt.Enter() and ctxt.Exit() | |||
| /// 2. op has _control_flow_context: the context to which the op belongs. | |||
| /// Set at the time the op is created.Immutable. | |||
| /// 3. A ControlFlowContext has _outer_context: the context in which this | |||
| /// context is created.Set at the time a context is created.Immutable. | |||
| /// 4. A ControlFlowContext has _context_stack. | |||
| /// Pushed and popped by ctxt.Enter() and ctxt.Exit() | |||
| /// </summary> | |||
| public abstract class ControlFlowContext : IPython, IControlFlowContext | |||
| { | |||
| /// <summary> | |||
| @@ -17,6 +34,8 @@ namespace Tensorflow.Operations | |||
| _context_stack = new Stack<IControlFlowContext>(); | |||
| } | |||
| public string name { get; set; } | |||
| public void __init__() | |||
| { | |||
| @@ -26,6 +45,13 @@ namespace Tensorflow.Operations | |||
| { | |||
| } | |||
| public void __exit__() | |||
| { | |||
| } | |||
| /// <summary> | |||
| /// Enter this control flow context. | |||
| /// </summary> | |||
| public virtual void Enter() | |||
| { | |||
| var graph = ops.get_default_graph(); | |||
| @@ -33,6 +59,16 @@ namespace Tensorflow.Operations | |||
| graph._set_control_flow_context(this); | |||
| } | |||
| /// <summary> | |||
| /// Exit this control flow context. | |||
| /// </summary> | |||
| public virtual void Exit() | |||
| { | |||
| var graph = ops.get_default_graph(); | |||
| var last_context = _context_stack.Pop(); | |||
| graph._set_control_flow_context(last_context); | |||
| } | |||
| public void AddOp(Operation op) | |||
| { | |||
| _AddOpInternal(op); | |||
| @@ -56,17 +92,6 @@ namespace Tensorflow.Operations | |||
| var internal_control_inputs = op.control_inputs; | |||
| } | |||
| public void Exit() | |||
| { | |||
| var graph = ops.get_default_graph(); | |||
| var last_context = _context_stack.Pop(); | |||
| graph._set_control_flow_context(last_context); | |||
| } | |||
| public void __exit__() | |||
| { | |||
| } | |||
| public void Dispose() | |||
| { | |||
| } | |||
| @@ -187,6 +187,48 @@ namespace Tensorflow | |||
| return @switch(data, pred, name: name); | |||
| } | |||
| /// <summary> | |||
| /// Return `true_fn()` if the predicate `pred` is true else `false_fn()`. | |||
| /// | |||
| /// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and | |||
| /// `false_fn` must have the same non-zero number and type of outputs. | |||
| /// | |||
| /// **WARNING**: Any Tensors or Operations created outside of `true_fn` and | |||
| /// `false_fn` will be executed regardless of which branch is selected at runtime. | |||
| /// | |||
| /// Although this behavior is consistent with the dataflow model of TensorFlow, | |||
| /// it has frequently surprised users who expected a lazier semantics. | |||
| /// Consider the following simple program: | |||
| /// | |||
| /// z = tf.multiply(a, b) | |||
| /// result = tf.cond(x < y, ()=> tf.add(x, z), ()=> tf.square(y)) | |||
| /// | |||
| /// If `x<y`, the `tf.add` operation will be executed and `tf.square` | |||
| /// operation will not be executed.Since `z` is needed for at least one | |||
| /// branch of the `cond`, the `tf.multiply` operation is always executed, | |||
| /// unconditionally. | |||
| /// | |||
| /// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the | |||
| /// call to `cond`, and not at all during `Session.run()`). `cond` | |||
| /// stitches together the graph fragments created during the `true_fn` and | |||
| /// `false_fn` calls with some additional graph nodes to ensure that the right | |||
| /// branch gets executed depending on the value of `pred`. | |||
| /// | |||
| /// `tf.cond` supports nested structures as implemented in | |||
| /// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the | |||
| /// same(possibly nested) value structure of lists, tuples, and/or named tuples. | |||
| /// Singleton lists and tuples form the only exceptions to this: when returned by | |||
| /// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. | |||
| /// This behavior is disabled by passing `strict= True`. | |||
| /// </summary> | |||
| /// <param name="pred"> A scalar determining whether to return the result of `true_fn` or | |||
| /// `false_fn`.</param> | |||
| /// <param name="true_fn">The callable to be performed if pred is true.</param> | |||
| /// <param name="false_fn">The callable to be performed if pred is false.</param> | |||
| /// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.</param> | |||
| /// <param name="name">Optional name prefix for the returned tensors.</param> | |||
| /// <returns>Tensors returned by the call to either `true_fn` or `false_fn`. If the | |||
| /// callables return a singleton list, the element is extracted from the list.</returns> | |||
| public static Tensor cond(Tensor pred, | |||
| Func<ITensorOrOperation> true_fn = null, | |||
| Func<ITensorOrOperation> false_fn = null, | |||
| @@ -195,6 +237,37 @@ namespace Tensorflow | |||
| { | |||
| return with(ops.name_scope(name, "cond", new { pred }), delegate | |||
| { | |||
| // TODO: here a chunk of original code is missing | |||
| /* | |||
| if fn1 is not None: | |||
| if true_fn is not None: | |||
| raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.") | |||
| true_fn = fn1 | |||
| elif true_fn is None: | |||
| raise TypeError("cond(): true_fn argument required") | |||
| if fn2 is not None: | |||
| if false_fn is not None: | |||
| raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.") | |||
| false_fn = fn2 | |||
| elif false_fn is None: | |||
| raise TypeError("cond(): false_fn argument required") | |||
| if not callable(true_fn): | |||
| raise TypeError("true_fn must be callable.") | |||
| if not callable(false_fn): | |||
| raise TypeError("false_fn must be callable.") | |||
| with ops.name_scope(name, "cond", [pred]): | |||
| if context.executing_eagerly(): | |||
| if pred: | |||
| return _UnpackIfSingleton(true_fn()) | |||
| return _UnpackIfSingleton(false_fn()) | |||
| # Add the Switch to the graph. | |||
| if isinstance(pred, bool): | |||
| raise TypeError("pred must not be a Python bool") | |||
| */ | |||
| // Add the Switch to the graph. | |||
| var (p_2, p_1) = @switch(pred, pred); | |||
| var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | |||
| @@ -207,15 +280,45 @@ namespace Tensorflow | |||
| // Build the graph for the true branch in a new context. | |||
| var context_t = new CondContext(pred, pivot_1, branch: 1); | |||
| context_t.Enter(); | |||
| var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | |||
| context_t.Exit(); | |||
| ITensorOrOperation orig_res_t; | |||
| Tensor res_t; | |||
| try | |||
| { | |||
| context_t.Enter(); | |||
| (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | |||
| } | |||
| finally | |||
| { | |||
| context_t.Exit(); | |||
| } | |||
| // Build the graph for the false branch in a new context. | |||
| var context_f = new CondContext(pred, pivot_2, branch: 0); | |||
| context_f.Enter(); | |||
| var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | |||
| context_f.Exit(); | |||
| ITensorOrOperation orig_res_f; | |||
| Tensor res_f; | |||
| try | |||
| { | |||
| context_f.Enter(); | |||
| (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | |||
| } | |||
| finally | |||
| { | |||
| context_f.Exit(); | |||
| } | |||
| //TODO: missing original code | |||
| //if not strict: | |||
| // orig_res_t = _UnpackIfSingleton(orig_res_t) | |||
| // orig_res_f = _UnpackIfSingleton(orig_res_f) | |||
| /* | |||
| # Check that the return values of the two branches have the same structure. | |||
| try: | |||
| nest.assert_same_structure(orig_res_t, orig_res_f) | |||
| except TypeError as e: | |||
| raise TypeError( | |||
| "Incompatible return types of true_fn and false_fn: {}".format(e)) | |||
| except ValueError as e: | |||
| raise ValueError( | |||
| "Incompatible return values of true_fn and false_fn: {}".format(e)) | |||
| var res_t_flat = new Tensor[] { res_t }; | |||
| var res_f_flat = new Tensor[] { res_f }; | |||
| @@ -1,5 +1,6 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.ComponentModel; | |||
| using System.Linq; | |||
| @@ -17,8 +18,8 @@ namespace Tensorflow | |||
| Console.WriteLine(obj.ToString()); | |||
| } | |||
| protected int len(Array a) | |||
| => a.Length; | |||
| protected int len<T>(IEnumerable<T> a) | |||
| => a.Count(); | |||
| protected IEnumerable<int> range(int end) | |||
| { | |||
| @@ -61,115 +61,150 @@ namespace TensorFlowNET.UnitTest | |||
| self.assertEqual(op4.name, "myop_1_1"); | |||
| }); | |||
| } | |||
| [Ignore("Something is not right, Switch gets not inserted correctly?")] | |||
| [TestMethod] | |||
| public void TestCond() | |||
| { | |||
| var graph = tf.Graph().as_default(); | |||
| with<Graph>(graph, g => | |||
| { | |||
| var x = constant_op.constant(10); | |||
| var true_fn = new Func<Tensor>(() => | |||
| { | |||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); | |||
| var new_ops = g._add_new_tf_operations(); | |||
| self.assertEqual(len(new_ops), 1); | |||
| return x; | |||
| }); | |||
| control_flow_ops.cond(x < 10, true_fn, () => x); | |||
| var op = g.get_operation_by_name("cond/myop"); | |||
| self.assertIsNotNone(op); | |||
| self.assertEqual(op.name, "cond/myop"); | |||
| self.assertEqual(op.type, "Identity"); | |||
| //self.assertEqual(op.outputs, new object[0]); | |||
| var op_input = op.inputs[0].op; | |||
| self.assertEqual(op_input.type, "Switch"); | |||
| self.assertEqual(op_input.inputs[0], x); | |||
| self.assertEqual(op.graph, g); | |||
| self.assertIsNotNone(op._get_control_flow_context()); | |||
| // TODO: op._get_control_flow_context().name not implemented | |||
| //self.assertEqual(op._get_control_flow_context().name, "cond/cond_text"); | |||
| }); | |||
| /* | |||
| @test_util.run_v1_only("b/120545219") | |||
| def testCond(self): | |||
| g = ops.Graph() | |||
| with g.as_default(): | |||
| x = test_ops.int_output() | |||
| def true_fn(): | |||
| ops._create_c_op(ops.get_default_graph(), | |||
| ops._NodeDef("IntInput", "cond/myop"), [x], []) | |||
| new_ops = g._add_new_tf_operations() | |||
| self.assertEqual(len(new_ops), 1) | |||
| return x | |||
| control_flow_ops.cond(x < 10, true_fn, lambda: x) | |||
| op = g.get_operation_by_name("cond/myop") | |||
| self.assertIsNotNone(op) | |||
| self.assertEqual(op.name, "cond/myop") | |||
| self.assertEqual(op.type, "IntInput") | |||
| self.assertEqual(op.outputs, []) | |||
| op_input = op.inputs[0].op | |||
| self.assertEqual(op_input.type, "Switch") | |||
| self.assertEqual(op_input.inputs[0], x) | |||
| self.assertEqual(op.graph, g) | |||
| # pylint: disable=protected-access | |||
| self.assertIsNotNone(op._get_control_flow_context()) | |||
| self.assertEqual(op._get_control_flow_context().name, | |||
| "cond/cond_text") | |||
| # pylint: enable=protected-access | |||
| */ | |||
| } | |||
| /* | |||
| @test_util.run_v1_only("b/120545219") | |||
| def testCond(self): | |||
| g = ops.Graph() | |||
| with g.as_default(): | |||
| x = test_ops.int_output() | |||
| def true_fn(): | |||
| ops._create_c_op(ops.get_default_graph(), | |||
| ops._NodeDef("IntInput", "cond/myop"), [x], []) | |||
| new_ops = g._add_new_tf_operations() | |||
| self.assertEqual(len(new_ops), 1) | |||
| return x | |||
| control_flow_ops.cond(x < 10, true_fn, lambda: x) | |||
| op = g.get_operation_by_name("cond/myop") | |||
| self.assertIsNotNone(op) | |||
| self.assertEqual(op.name, "cond/myop") | |||
| self.assertEqual(op.type, "IntInput") | |||
| self.assertEqual(op.outputs, []) | |||
| op_input = op.inputs[0].op | |||
| self.assertEqual(op_input.type, "Switch") | |||
| self.assertEqual(op_input.inputs[0], x) | |||
| self.assertEqual(op.graph, g) | |||
| # pylint: disable=protected-access | |||
| self.assertIsNotNone(op._get_control_flow_context()) | |||
| self.assertEqual(op._get_control_flow_context().name, | |||
| "cond/cond_text") | |||
| # pylint: enable=protected-access | |||
| @test_util.run_v1_only("b/120545219") | |||
| def testWhileLoop(self): | |||
| g = ops.Graph() | |||
| with g.as_default(): | |||
| x = test_ops.int_output() | |||
| def body(i): | |||
| ops._create_c_op(ops.get_default_graph(), | |||
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
| new_ops = g._add_new_tf_operations() | |||
| self.assertEqual(len(new_ops), 1) | |||
| return i | |||
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
| op = g.get_operation_by_name("myloop/myop") | |||
| self.assertIsNotNone(op) | |||
| self.assertEqual(op.name, "myloop/myop") | |||
| self.assertEqual(op.type, "IntInput") | |||
| self.assertEqual(op.outputs, []) | |||
| op_input = op.inputs[0].op | |||
| self.assertEqual(op_input.type, "Enter") | |||
| self.assertEqual(list(op_input.inputs), [x]) | |||
| self.assertEqual(op.graph, g) | |||
| # pylint: disable=protected-access | |||
| self.assertIsNotNone(op._get_control_flow_context()) | |||
| self.assertEqual(op._get_control_flow_context().name, | |||
| "myloop/while_context") | |||
| # pylint: enable=protected-access | |||
| @test_util.run_v1_only("b/120545219") | |||
| def testWhileLoopWithInternalControlDep(self): | |||
| g = ops.Graph() | |||
| with g.as_default(): | |||
| x = test_ops.int_output() | |||
| def body(i): | |||
| c = constant_op.constant(1.0, name="c") | |||
| ops._create_c_op(ops.get_default_graph(), | |||
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
| with ops.control_dependencies([c]): | |||
| new_ops = g._add_new_tf_operations() | |||
| self.assertEqual(len(new_ops), 1) | |||
| return i | |||
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
| op = g.get_operation_by_name("myloop/myop") | |||
| self.assertIsNotNone(op) | |||
| c = g.get_operation_by_name("myloop/c") | |||
| self.assertIsNotNone(c) | |||
| # Internal control dep is preserved | |||
| self.assertEqual(op.control_inputs, [c]) | |||
| @test_util.run_v1_only("b/120545219") | |||
| def testWhileLoopWithExternalControlDep(self): | |||
| g = ops.Graph() | |||
| with g.as_default(): | |||
| x = test_ops.int_output() | |||
| c = constant_op.constant(1.0) | |||
| def body(i): | |||
| ops._create_c_op(ops.get_default_graph(), | |||
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
| with ops.control_dependencies([c]): | |||
| new_ops = g._add_new_tf_operations() | |||
| self.assertEqual(len(new_ops), 1) | |||
| return i | |||
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
| op = g.get_operation_by_name("myloop/myop") | |||
| self.assertIsNotNone(op) | |||
| # External control dep is removed and replaced with internal control dep | |||
| self.assertNotEqual(op.control_inputs[0], c.op) | |||
| self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) | |||
| */ | |||
| @test_util.run_v1_only("b/120545219") | |||
| def testWhileLoop(self): | |||
| g = ops.Graph() | |||
| with g.as_default(): | |||
| x = test_ops.int_output() | |||
| def body(i): | |||
| ops._create_c_op(ops.get_default_graph(), | |||
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
| new_ops = g._add_new_tf_operations() | |||
| self.assertEqual(len(new_ops), 1) | |||
| return i | |||
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
| op = g.get_operation_by_name("myloop/myop") | |||
| self.assertIsNotNone(op) | |||
| self.assertEqual(op.name, "myloop/myop") | |||
| self.assertEqual(op.type, "IntInput") | |||
| self.assertEqual(op.outputs, []) | |||
| op_input = op.inputs[0].op | |||
| self.assertEqual(op_input.type, "Enter") | |||
| self.assertEqual(list(op_input.inputs), [x]) | |||
| self.assertEqual(op.graph, g) | |||
| # pylint: disable=protected-access | |||
| self.assertIsNotNone(op._get_control_flow_context()) | |||
| self.assertEqual(op._get_control_flow_context().name, | |||
| "myloop/while_context") | |||
| # pylint: enable=protected-access | |||
| @test_util.run_v1_only("b/120545219") | |||
| def testWhileLoopWithInternalControlDep(self): | |||
| g = ops.Graph() | |||
| with g.as_default(): | |||
| x = test_ops.int_output() | |||
| def body(i): | |||
| c = constant_op.constant(1.0, name="c") | |||
| ops._create_c_op(ops.get_default_graph(), | |||
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
| with ops.control_dependencies([c]): | |||
| new_ops = g._add_new_tf_operations() | |||
| self.assertEqual(len(new_ops), 1) | |||
| return i | |||
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
| op = g.get_operation_by_name("myloop/myop") | |||
| self.assertIsNotNone(op) | |||
| c = g.get_operation_by_name("myloop/c") | |||
| self.assertIsNotNone(c) | |||
| # Internal control dep is preserved | |||
| self.assertEqual(op.control_inputs, [c]) | |||
| @test_util.run_v1_only("b/120545219") | |||
| def testWhileLoopWithExternalControlDep(self): | |||
| g = ops.Graph() | |||
| with g.as_default(): | |||
| x = test_ops.int_output() | |||
| c = constant_op.constant(1.0) | |||
| def body(i): | |||
| ops._create_c_op(ops.get_default_graph(), | |||
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
| with ops.control_dependencies([c]): | |||
| new_ops = g._add_new_tf_operations() | |||
| self.assertEqual(len(new_ops), 1) | |||
| return i | |||
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
| op = g.get_operation_by_name("myloop/myop") | |||
| self.assertIsNotNone(op) | |||
| # External control dep is removed and replaced with internal control dep | |||
| self.assertNotEqual(op.control_inputs[0], c.op) | |||
| self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) | |||
| */ | |||
| } | |||
| } | |||
| @@ -29,6 +29,11 @@ namespace TensorFlowNET.UnitTest | |||
| Assert.AreEqual(expected, given); | |||
| } | |||
| public void assertIsNotNone(object given) | |||
| { | |||
| Assert.IsNotNull(given); | |||
| } | |||
| protected PythonTest self { get => this; } | |||
| } | |||
| } | |||