@@ -8,6 +8,7 @@ namespace Tensorflow
{
    public partial class Graph
    {
+       // Current control flow context. It could be either CondContext or WhileContext
        public IControlFlowContext _control_flow_context;
        // represents the nested with(...) statements
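For orientation, a minimal sketch of how this field is toggled (illustrative only; `ctx` and `outer` stand for arbitrary `IControlFlowContext` instances):

    // The active context is swapped by ctxt.Enter()/ctxt.Exit(); new ops pick it up at creation time.
    graph._set_control_flow_context(ctx);        // what ctx.Enter() does
    var current = graph._control_flow_context;   // consulted when a new op is created
    graph._set_control_flow_context(outer);      // what ctx.Exit() does, restoring the outer context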
@@ -64,6 +64,9 @@ namespace Tensorflow.Operations
            }
        }

+       /// <summary>
+       /// Add the subgraph defined by fn() to the graph.
+       /// </summary>
        public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
        {
            // Add the subgraph defined by fn() to the graph.
@@ -71,6 +74,22 @@ namespace Tensorflow.Operations
            var original_result = fn();
            var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);

+           // TODO: port this chunk of missing code:
+           /*
+           if len(post_summaries) > len(pre_summaries):
+             new_summaries = post_summaries[len(pre_summaries):]
+             summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
+             summary_ref[:] = pre_summaries
+             with ops.control_dependencies(new_summaries):
+               if original_result is None:
+                 return no_op(), None
+               else:
+                 original_result = nest.map_structure(array_ops.identity, original_result)
+           */
+
+           if (original_result == null)
+               return (original_result, null);
+
            switch (original_result)
            {
                case Operation[] results:
@@ -3,7 +3,24 @@ using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
+   /// <summary>
+   /// The base class for control flow context.
+   ///
+   /// The usage pattern is a sequence of (Enter, Exit) followed by a final
+   /// ExitResult.
+   ///
+   /// We maintain the following state for control flow contexts during graph
+   /// construction:
+   /// 1. graph has _control_flow_context: the current context used to
+   ///    construct new nodes. Changed by ctxt.Enter() and ctxt.Exit().
+   /// 2. op has _control_flow_context: the context to which the op belongs.
+   ///    Set at the time the op is created. Immutable.
+   /// 3. A ControlFlowContext has _outer_context: the context in which this
+   ///    context is created. Set at the time a context is created. Immutable.
+   /// 4. A ControlFlowContext has _context_stack.
+   ///    Pushed and popped by ctxt.Enter() and ctxt.Exit().
+   /// </summary>
    public abstract class ControlFlowContext : IPython, IControlFlowContext
    {
        /// <summary>
@@ -17,6 +34,8 @@ namespace Tensorflow.Operations
            _context_stack = new Stack<IControlFlowContext>();
        }

+       public string name { get; set; }
+
        public void __init__()
        {
@@ -26,6 +45,13 @@ namespace Tensorflow.Operations
        {
        }

+       public void __exit__()
+       {
+       }
+
+       /// <summary>
+       /// Enter this control flow context.
+       /// </summary>
        public virtual void Enter()
        {
            var graph = ops.get_default_graph();
@@ -33,6 +59,16 @@ namespace Tensorflow.Operations
            graph._set_control_flow_context(this);
        }

+       /// <summary>
+       /// Exit this control flow context.
+       /// </summary>
+       public virtual void Exit()
+       {
+           var graph = ops.get_default_graph();
+           var last_context = _context_stack.Pop();
+           graph._set_control_flow_context(last_context);
+       }
+
        public void AddOp(Operation op)
        {
            _AddOpInternal(op);
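As a usage sketch of the Enter/Exit pair added above (mirroring how `cond` drives a `CondContext` later in this patch; `pred`, `pivot` and `build_branch` are assumed placeholders for an existing predicate tensor, pivot tensor and branch delegate):

    var ctx = new CondContext(pred, pivot, branch: 1);
    try
    {
        ctx.Enter();                                          // push the currently active context and make ctx current
        var (orig, res) = ctx.BuildCondBranch(build_branch);  // ops created by the delegate are attributed to ctx
    }
    finally
    {
        ctx.Exit();                                           // restore the previous context even if the branch throws
    }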
@@ -56,17 +92,6 @@ namespace Tensorflow.Operations
            var internal_control_inputs = op.control_inputs;
        }

-       public void Exit()
-       {
-           var graph = ops.get_default_graph();
-           var last_context = _context_stack.Pop();
-           graph._set_control_flow_context(last_context);
-       }
-
-       public void __exit__()
-       {
-       }
-
        public void Dispose()
        {
        }
@@ -187,6 +187,48 @@ namespace Tensorflow
            return @switch(data, pred, name: name);
        }

+       /// <summary>
+       /// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
+       ///
+       /// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
+       /// `false_fn` must have the same non-zero number and type of outputs.
+       ///
+       /// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
+       /// `false_fn` will be executed regardless of which branch is selected at runtime.
+       ///
+       /// Although this behavior is consistent with the dataflow model of TensorFlow,
+       /// it has frequently surprised users who expected lazier semantics.
+       /// Consider the following simple program:
+       ///
+       /// z = tf.multiply(a, b)
+       /// result = tf.cond(x < y, () => tf.add(x, z), () => tf.square(y))
+       ///
+       /// If `x < y`, the `tf.add` operation will be executed and the `tf.square`
+       /// operation will not be executed. Since `z` is needed for at least one
+       /// branch of the `cond`, the `tf.multiply` operation is always executed,
+       /// unconditionally.
+       ///
+       /// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
+       /// call to `cond`, and not at all during `Session.run()`). `cond`
+       /// stitches together the graph fragments created during the `true_fn` and
+       /// `false_fn` calls with some additional graph nodes to ensure that the right
+       /// branch gets executed depending on the value of `pred`.
+       ///
+       /// `tf.cond` supports nested structures as implemented in
+       /// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
+       /// same (possibly nested) value structure of lists, tuples, and/or named tuples.
+       /// Singleton lists and tuples form the only exceptions to this: when returned by
+       /// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
+       /// This behavior is disabled by passing `strict: true`.
+       /// </summary>
+       /// <param name="pred">A scalar determining whether to return the result of `true_fn` or `false_fn`.</param>
+       /// <param name="true_fn">The callable to be performed if pred is true.</param>
+       /// <param name="false_fn">The callable to be performed if pred is false.</param>
+       /// <param name="strict">A boolean that enables/disables 'strict' mode; see above.</param>
+       /// <param name="name">Optional name prefix for the returned tensors.</param>
+       /// <returns>Tensors returned by the call to either `true_fn` or `false_fn`. If the
+       /// callables return a singleton list, the element is extracted from the list.</returns>
        public static Tensor cond(Tensor pred,
            Func<ITensorOrOperation> true_fn = null,
            Func<ITensorOrOperation> false_fn = null,
@@ -195,6 +237,37 @@ namespace Tensorflow
        {
            return with(ops.name_scope(name, "cond", new { pred }), delegate
            {
+               // TODO: here a chunk of the original code is missing:
+               /*
+               if fn1 is not None:
+                 if true_fn is not None:
+                   raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
+                 true_fn = fn1
+               elif true_fn is None:
+                 raise TypeError("cond(): true_fn argument required")
+               if fn2 is not None:
+                 if false_fn is not None:
+                   raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
+                 false_fn = fn2
+               elif false_fn is None:
+                 raise TypeError("cond(): false_fn argument required")
+
+               if not callable(true_fn):
+                 raise TypeError("true_fn must be callable.")
+               if not callable(false_fn):
+                 raise TypeError("false_fn must be callable.")
+
+               with ops.name_scope(name, "cond", [pred]):
+                 if context.executing_eagerly():
+                   if pred:
+                     return _UnpackIfSingleton(true_fn())
+                   return _UnpackIfSingleton(false_fn())
+
+                 # Add the Switch to the graph.
+                 if isinstance(pred, bool):
+                   raise TypeError("pred must not be a Python bool")
+               */
+
                // Add the Switch to the graph.
                var (p_2, p_1) = @switch(pred, pred);
                var pivot_1 = array_ops.identity(p_1, name: "switch_t");
@@ -207,15 +280,45 @@ namespace Tensorflow
                // Build the graph for the true branch in a new context.
                var context_t = new CondContext(pred, pivot_1, branch: 1);
-               context_t.Enter();
-               var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
-               context_t.Exit();
+               ITensorOrOperation orig_res_t;
+               Tensor res_t;
+               try
+               {
+                   context_t.Enter();
+                   (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
+               }
+               finally
+               {
+                   context_t.Exit();
+               }

                // Build the graph for the false branch in a new context.
                var context_f = new CondContext(pred, pivot_2, branch: 0);
-               context_f.Enter();
-               var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
-               context_f.Exit();
+               ITensorOrOperation orig_res_f;
+               Tensor res_f;
+               try
+               {
+                   context_f.Enter();
+                   (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
+               }
+               finally
+               {
+                   context_f.Exit();
+               }

+               // TODO: missing original code
+               //if not strict:
+               //    orig_res_t = _UnpackIfSingleton(orig_res_t)
+               //    orig_res_f = _UnpackIfSingleton(orig_res_f)
+
+               /*
+               # Check that the return values of the two branches have the same structure.
+               try:
+                 nest.assert_same_structure(orig_res_t, orig_res_f)
+               except TypeError as e:
+                 raise TypeError(
+                     "Incompatible return types of true_fn and false_fn: {}".format(e))
+               except ValueError as e:
+                 raise ValueError(
+                     "Incompatible return values of true_fn and false_fn: {}".format(e))
+               */
+
                var res_t_flat = new Tensor[] { res_t };
                var res_f_flat = new Tensor[] { res_f };
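To make the warning in the new XML doc concrete, here is a small usage sketch (illustrative only; it assumes scalar tensors `a`, `b`, `x`, `y` already exist and that the corresponding math ops and the `<` operator are exposed by this binding):

    var z = tf.multiply(a, b);                   // created outside the branches, so it always executes
    var result = control_flow_ops.cond(x < y,
        () => tf.add(x, z),                      // runs only when x < y at graph-execution time
        () => tf.square(y));                     // runs otherwise; z is still computed either way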
@@ -1,5 +1,6 @@
using NumSharp;
using System;
+using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
@@ -17,8 +18,8 @@ namespace Tensorflow
            Console.WriteLine(obj.ToString());
        }

-       protected int len(Array a)
-           => a.Length;
+       protected int len<T>(IEnumerable<T> a)
+           => a.Count();

        protected IEnumerable<int> range(int end)
        {
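The widened `len` overload matters for the ported tests below, which call it on the operation list returned by `_add_new_tf_operations()`; a quick illustration:

    var n1 = len(new[] { 1, 2, 3 });               // 3 (arrays still work, via IEnumerable<T>)
    var n2 = len(new List<string> { "a", "b" });   // 2 (any other collection now works too)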
@@ -61,115 +61,150 @@ namespace TensorFlowNET.UnitTest
                self.assertEqual(op4.name, "myop_1_1");
            });
        }

+       [Ignore("Something is not right, Switch does not get inserted correctly?")]
+       [TestMethod]
+       public void TestCond()
+       {
+           var graph = tf.Graph().as_default();
+           with<Graph>(graph, g =>
+           {
+               var x = constant_op.constant(10);
+               var true_fn = new Func<Tensor>(() =>
+               {
+                   var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]);
+                   var new_ops = g._add_new_tf_operations();
+                   self.assertEqual(len(new_ops), 1);
+                   return x;
+               });
+
+               control_flow_ops.cond(x < 10, true_fn, () => x);
+
+               var op = g.get_operation_by_name("cond/myop");
+               self.assertIsNotNone(op);
+               self.assertEqual(op.name, "cond/myop");
+               self.assertEqual(op.type, "Identity");
+               //self.assertEqual(op.outputs, new object[0]);
+
+               var op_input = op.inputs[0].op;
+               self.assertEqual(op_input.type, "Switch");
+               self.assertEqual(op_input.inputs[0], x);
+               self.assertEqual(op.graph, g);
+               self.assertIsNotNone(op._get_control_flow_context());
+               // TODO: op._get_control_flow_context().name not implemented
+               //self.assertEqual(op._get_control_flow_context().name, "cond/cond_text");
+           });
+
+           /*
+           @test_util.run_v1_only("b/120545219")
+           def testCond(self):
+             g = ops.Graph()
+             with g.as_default():
+               x = test_ops.int_output()
+
+               def true_fn():
+                 ops._create_c_op(ops.get_default_graph(),
+                                  ops._NodeDef("IntInput", "cond/myop"), [x], [])
+                 new_ops = g._add_new_tf_operations()
+                 self.assertEqual(len(new_ops), 1)
+                 return x
+
+               control_flow_ops.cond(x < 10, true_fn, lambda: x)
+
+             op = g.get_operation_by_name("cond/myop")
+             self.assertIsNotNone(op)
+             self.assertEqual(op.name, "cond/myop")
+             self.assertEqual(op.type, "IntInput")
+             self.assertEqual(op.outputs, [])
+             op_input = op.inputs[0].op
+             self.assertEqual(op_input.type, "Switch")
+             self.assertEqual(op_input.inputs[0], x)
+             self.assertEqual(op.graph, g)
+             # pylint: disable=protected-access
+             self.assertIsNotNone(op._get_control_flow_context())
+             self.assertEqual(op._get_control_flow_context().name,
+                              "cond/cond_text")
+             # pylint: enable=protected-access
+           */
+       }
        /*
-       @test_util.run_v1_only("b/120545219")
-       def testCond(self):
-         g = ops.Graph()
-         with g.as_default():
-           x = test_ops.int_output()
-
-           def true_fn():
-             ops._create_c_op(ops.get_default_graph(),
-                              ops._NodeDef("IntInput", "cond/myop"), [x], [])
-             new_ops = g._add_new_tf_operations()
-             self.assertEqual(len(new_ops), 1)
-             return x
-
-           control_flow_ops.cond(x < 10, true_fn, lambda: x)
-
-         op = g.get_operation_by_name("cond/myop")
-         self.assertIsNotNone(op)
-         self.assertEqual(op.name, "cond/myop")
-         self.assertEqual(op.type, "IntInput")
-         self.assertEqual(op.outputs, [])
-         op_input = op.inputs[0].op
-         self.assertEqual(op_input.type, "Switch")
-         self.assertEqual(op_input.inputs[0], x)
-         self.assertEqual(op.graph, g)
-         # pylint: disable=protected-access
-         self.assertIsNotNone(op._get_control_flow_context())
-         self.assertEqual(op._get_control_flow_context().name,
-                          "cond/cond_text")
-         # pylint: enable=protected-access
-
        @test_util.run_v1_only("b/120545219")
        def testWhileLoop(self):
          g = ops.Graph()
          with g.as_default():
            x = test_ops.int_output()

            def body(i):
              ops._create_c_op(ops.get_default_graph(),
                               ops._NodeDef("IntInput", "myloop/myop"), [x], [])
              new_ops = g._add_new_tf_operations()
              self.assertEqual(len(new_ops), 1)
              return i

            control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

          op = g.get_operation_by_name("myloop/myop")
          self.assertIsNotNone(op)
          self.assertEqual(op.name, "myloop/myop")
          self.assertEqual(op.type, "IntInput")
          self.assertEqual(op.outputs, [])
          op_input = op.inputs[0].op
          self.assertEqual(op_input.type, "Enter")
          self.assertEqual(list(op_input.inputs), [x])
          self.assertEqual(op.graph, g)
          # pylint: disable=protected-access
          self.assertIsNotNone(op._get_control_flow_context())
          self.assertEqual(op._get_control_flow_context().name,
                           "myloop/while_context")
          # pylint: enable=protected-access

        @test_util.run_v1_only("b/120545219")
        def testWhileLoopWithInternalControlDep(self):
          g = ops.Graph()
          with g.as_default():
            x = test_ops.int_output()

            def body(i):
              c = constant_op.constant(1.0, name="c")
              ops._create_c_op(ops.get_default_graph(),
                               ops._NodeDef("IntInput", "myloop/myop"), [x], [])
              with ops.control_dependencies([c]):
                new_ops = g._add_new_tf_operations()
                self.assertEqual(len(new_ops), 1)
              return i

            control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

          op = g.get_operation_by_name("myloop/myop")
          self.assertIsNotNone(op)
          c = g.get_operation_by_name("myloop/c")
          self.assertIsNotNone(c)
          # Internal control dep is preserved
          self.assertEqual(op.control_inputs, [c])

        @test_util.run_v1_only("b/120545219")
        def testWhileLoopWithExternalControlDep(self):
          g = ops.Graph()
          with g.as_default():
            x = test_ops.int_output()
            c = constant_op.constant(1.0)

            def body(i):
              ops._create_c_op(ops.get_default_graph(),
                               ops._NodeDef("IntInput", "myloop/myop"), [x], [])
              with ops.control_dependencies([c]):
                new_ops = g._add_new_tf_operations()
                self.assertEqual(len(new_ops), 1)
              return i

            control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

          op = g.get_operation_by_name("myloop/myop")
          self.assertIsNotNone(op)
          # External control dep is removed and replaced with internal control dep
          self.assertNotEqual(op.control_inputs[0], c.op)
          self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
        */
    }
}
@@ -29,6 +29,11 @@ namespace TensorFlowNET.UnitTest
            Assert.AreEqual(expected, given);
        }

+       public void assertIsNotNone(object given)
+       {
+           Assert.IsNotNull(given);
+       }
+
        protected PythonTest self { get => this; }
    }
}