diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
index 0ac0becf..42cf1a17 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
@@ -8,6 +8,7 @@ namespace Tensorflow
{
public partial class Graph
{
+ // Current control flow context. It could be either CondContext or WhileContext
public IControlFlowContext _control_flow_context;
// represents the nested with(...) statements
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
index 0385341a..a0d84e89 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
@@ -64,6 +64,9 @@ namespace Tensorflow.Operations
}
}
+ ///
+ /// Add the subgraph defined by fn() to the graph.
+ ///
public (T, Tensor) BuildCondBranch(Func fn)
{
// Add the subgraph defined by fn() to the graph.
@@ -71,6 +74,22 @@ namespace Tensorflow.Operations
var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+ //TODO: port this chunck of missing code:
+ /*
+ if len(post_summaries) > len(pre_summaries):
+ new_summaries = post_summaries[len(pre_summaries):]
+ summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
+ summary_ref[:] = pre_summaries
+ with ops.control_dependencies(new_summaries):
+ if original_result is None:
+ return no_op(), None
+ else:
+ original_result = nest.map_structure(array_ops.identity,
+ original_result)
+ */
+ if (original_result == null)
+ return (original_result, null);
+
switch (original_result)
{
case Operation[] results:
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
index 8776f171..9d039d58 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
@@ -3,7 +3,24 @@ using System.Collections.Generic;
using System.Text;
namespace Tensorflow.Operations
-{
+{
+ ///
+ /// The base class for control flow context.
+ ///
+ /// The usage pattern is a sequence of(Enter, Exit) followed by a final
+ /// ExitResult.
+ ///
+ /// We maintain the following state for control flow contexts during graph
+ /// construction:
+ /// 1. graph has _control_flow_context: the current context used to
+ /// construct new nodes.Changed by ctxt.Enter() and ctxt.Exit()
+ /// 2. op has _control_flow_context: the context to which the op belongs.
+ /// Set at the time the op is created.Immutable.
+ /// 3. A ControlFlowContext has _outer_context: the context in which this
+ /// context is created.Set at the time a context is created.Immutable.
+ /// 4. A ControlFlowContext has _context_stack.
+ /// Pushed and popped by ctxt.Enter() and ctxt.Exit()
+ ///
public abstract class ControlFlowContext : IPython, IControlFlowContext
{
///
@@ -17,6 +34,8 @@ namespace Tensorflow.Operations
_context_stack = new Stack();
}
+ public string name { get; set; }
+
public void __init__()
{
@@ -26,6 +45,13 @@ namespace Tensorflow.Operations
{
}
+ public void __exit__()
+ {
+ }
+
+ ///
+ /// Enter this control flow context.
+ ///
public virtual void Enter()
{
var graph = ops.get_default_graph();
@@ -33,6 +59,16 @@ namespace Tensorflow.Operations
graph._set_control_flow_context(this);
}
+ ///
+ /// Exit this control flow context.
+ ///
+ public virtual void Exit()
+ {
+ var graph = ops.get_default_graph();
+ var last_context = _context_stack.Pop();
+ graph._set_control_flow_context(last_context);
+ }
+
public void AddOp(Operation op)
{
_AddOpInternal(op);
@@ -56,17 +92,6 @@ namespace Tensorflow.Operations
var internal_control_inputs = op.control_inputs;
}
- public void Exit()
- {
- var graph = ops.get_default_graph();
- var last_context = _context_stack.Pop();
- graph._set_control_flow_context(last_context);
- }
-
- public void __exit__()
- {
- }
-
public void Dispose()
{
}
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
index 0b7afced..5c27d24d 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
@@ -187,6 +187,48 @@ namespace Tensorflow
return @switch(data, pred, name: name);
}
+ ///
+ /// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
+ ///
+ /// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
+ /// `false_fn` must have the same non-zero number and type of outputs.
+ ///
+ /// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
+ /// `false_fn` will be executed regardless of which branch is selected at runtime.
+ ///
+ /// Although this behavior is consistent with the dataflow model of TensorFlow,
+ /// it has frequently surprised users who expected a lazier semantics.
+ /// Consider the following simple program:
+ ///
+ /// z = tf.multiply(a, b)
+ /// result = tf.cond(x < y, ()=> tf.add(x, z), ()=> tf.square(y))
+ ///
+ /// If `x<y`, the `tf.add` operation will be executed and `tf.square`
+ /// operation will not be executed.Since `z` is needed for at least one
+ /// branch of the `cond`, the `tf.multiply` operation is always executed,
+ /// unconditionally.
+ ///
+ /// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
+ /// call to `cond`, and not at all during `Session.run()`). `cond`
+ /// stitches together the graph fragments created during the `true_fn` and
+ /// `false_fn` calls with some additional graph nodes to ensure that the right
+ /// branch gets executed depending on the value of `pred`.
+ ///
+ /// `tf.cond` supports nested structures as implemented in
+ /// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
+ /// same(possibly nested) value structure of lists, tuples, and/or named tuples.
+ /// Singleton lists and tuples form the only exceptions to this: when returned by
+ /// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
+ /// This behavior is disabled by passing `strict= True`.
+ ///
+ /// A scalar determining whether to return the result of `true_fn` or
+ /// `false_fn`.
+ /// The callable to be performed if pred is true.
+ /// The callable to be performed if pred is false.
+ /// A boolean that enables/disables 'strict' mode; see above.
+ /// Optional name prefix for the returned tensors.
+ /// Tensors returned by the call to either `true_fn` or `false_fn`. If the
+ /// callables return a singleton list, the element is extracted from the list.
public static Tensor cond(Tensor pred,
Func true_fn = null,
Func false_fn = null,
@@ -195,6 +237,37 @@ namespace Tensorflow
{
return with(ops.name_scope(name, "cond", new { pred }), delegate
{
+ // TODO: here a chunk of original code is missing
+ /*
+ if fn1 is not None:
+ if true_fn is not None:
+ raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
+ true_fn = fn1
+ elif true_fn is None:
+ raise TypeError("cond(): true_fn argument required")
+ if fn2 is not None:
+ if false_fn is not None:
+ raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
+ false_fn = fn2
+ elif false_fn is None:
+ raise TypeError("cond(): false_fn argument required")
+
+ if not callable(true_fn):
+ raise TypeError("true_fn must be callable.")
+ if not callable(false_fn):
+ raise TypeError("false_fn must be callable.")
+
+ with ops.name_scope(name, "cond", [pred]):
+ if context.executing_eagerly():
+ if pred:
+ return _UnpackIfSingleton(true_fn())
+ return _UnpackIfSingleton(false_fn())
+
+ # Add the Switch to the graph.
+ if isinstance(pred, bool):
+ raise TypeError("pred must not be a Python bool")
+ */
+
// Add the Switch to the graph.
var (p_2, p_1) = @switch(pred, pred);
var pivot_1 = array_ops.identity(p_1, name: "switch_t");
@@ -207,15 +280,45 @@ namespace Tensorflow
// Build the graph for the true branch in a new context.
var context_t = new CondContext(pred, pivot_1, branch: 1);
- context_t.Enter();
- var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
- context_t.Exit();
-
+ ITensorOrOperation orig_res_t;
+ Tensor res_t;
+ try
+ {
+ context_t.Enter();
+ (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
+ }
+ finally
+ {
+ context_t.Exit();
+ }
// Build the graph for the false branch in a new context.
var context_f = new CondContext(pred, pivot_2, branch: 0);
- context_f.Enter();
- var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
- context_f.Exit();
+ ITensorOrOperation orig_res_f;
+ Tensor res_f;
+ try
+ {
+ context_f.Enter();
+ (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
+ }
+ finally
+ {
+ context_f.Exit();
+ }
+
+ //TODO: missing original code
+ //if not strict:
+ // orig_res_t = _UnpackIfSingleton(orig_res_t)
+ // orig_res_f = _UnpackIfSingleton(orig_res_f)
+ /*
+ # Check that the return values of the two branches have the same structure.
+ try:
+ nest.assert_same_structure(orig_res_t, orig_res_f)
+ except TypeError as e:
+ raise TypeError(
+ "Incompatible return types of true_fn and false_fn: {}".format(e))
+ except ValueError as e:
+ raise ValueError(
+ "Incompatible return values of true_fn and false_fn: {}".format(e))
var res_t_flat = new Tensor[] { res_t };
var res_f_flat = new Tensor[] { res_f };
diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs
index ce859ec8..eaa57681 100644
--- a/src/TensorFlowNET.Core/Python.cs
+++ b/src/TensorFlowNET.Core/Python.cs
@@ -1,5 +1,6 @@
using NumSharp;
using System;
+using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
@@ -17,8 +18,8 @@ namespace Tensorflow
Console.WriteLine(obj.ToString());
}
- protected int len(Array a)
- => a.Length;
+ protected int len(IEnumerable a)
+ => a.Count();
protected IEnumerable range(int end)
{
diff --git a/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs
index 88e4da9d..78c9cd45 100644
--- a/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs
+++ b/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs
@@ -61,115 +61,150 @@ namespace TensorFlowNET.UnitTest
self.assertEqual(op4.name, "myop_1_1");
});
}
+
+ [Ignore("Something is not right, Switch gets not inserted correctly?")]
+ [TestMethod]
+ public void TestCond()
+ {
+ var graph = tf.Graph().as_default();
+ with(graph, g =>
+ {
+ var x = constant_op.constant(10);
+
+ var true_fn = new Func(() =>
+ {
+ var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]);
+ var new_ops = g._add_new_tf_operations();
+ self.assertEqual(len(new_ops), 1);
+ return x;
+ });
+
+ control_flow_ops.cond(x < 10, true_fn, () => x);
+
+ var op = g.get_operation_by_name("cond/myop");
+ self.assertIsNotNone(op);
+ self.assertEqual(op.name, "cond/myop");
+ self.assertEqual(op.type, "Identity");
+ //self.assertEqual(op.outputs, new object[0]);
+ var op_input = op.inputs[0].op;
+ self.assertEqual(op_input.type, "Switch");
+ self.assertEqual(op_input.inputs[0], x);
+ self.assertEqual(op.graph, g);
+ self.assertIsNotNone(op._get_control_flow_context());
+ // TODO: op._get_control_flow_context().name not implemented
+ //self.assertEqual(op._get_control_flow_context().name, "cond/cond_text");
+ });
+ /*
+ @test_util.run_v1_only("b/120545219")
+ def testCond(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+
+ def true_fn():
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "cond/myop"), [x], [])
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
+ return x
+
+ control_flow_ops.cond(x < 10, true_fn, lambda: x)
+
+ op = g.get_operation_by_name("cond/myop")
+ self.assertIsNotNone(op)
+ self.assertEqual(op.name, "cond/myop")
+ self.assertEqual(op.type, "IntInput")
+ self.assertEqual(op.outputs, [])
+ op_input = op.inputs[0].op
+ self.assertEqual(op_input.type, "Switch")
+ self.assertEqual(op_input.inputs[0], x)
+ self.assertEqual(op.graph, g)
+ # pylint: disable=protected-access
+ self.assertIsNotNone(op._get_control_flow_context())
+ self.assertEqual(op._get_control_flow_context().name,
+ "cond/cond_text")
+ # pylint: enable=protected-access
+ */
+ }
/*
- @test_util.run_v1_only("b/120545219")
- def testCond(self):
- g = ops.Graph()
- with g.as_default():
- x = test_ops.int_output()
-
- def true_fn():
- ops._create_c_op(ops.get_default_graph(),
- ops._NodeDef("IntInput", "cond/myop"), [x], [])
- new_ops = g._add_new_tf_operations()
- self.assertEqual(len(new_ops), 1)
- return x
-
- control_flow_ops.cond(x < 10, true_fn, lambda: x)
-
- op = g.get_operation_by_name("cond/myop")
- self.assertIsNotNone(op)
- self.assertEqual(op.name, "cond/myop")
- self.assertEqual(op.type, "IntInput")
- self.assertEqual(op.outputs, [])
- op_input = op.inputs[0].op
- self.assertEqual(op_input.type, "Switch")
- self.assertEqual(op_input.inputs[0], x)
- self.assertEqual(op.graph, g)
- # pylint: disable=protected-access
- self.assertIsNotNone(op._get_control_flow_context())
- self.assertEqual(op._get_control_flow_context().name,
- "cond/cond_text")
- # pylint: enable=protected-access
-
- @test_util.run_v1_only("b/120545219")
- def testWhileLoop(self):
- g = ops.Graph()
- with g.as_default():
- x = test_ops.int_output()
-
- def body(i):
- ops._create_c_op(ops.get_default_graph(),
- ops._NodeDef("IntInput", "myloop/myop"), [x], [])
- new_ops = g._add_new_tf_operations()
- self.assertEqual(len(new_ops), 1)
- return i
-
- control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
-
- op = g.get_operation_by_name("myloop/myop")
- self.assertIsNotNone(op)
- self.assertEqual(op.name, "myloop/myop")
- self.assertEqual(op.type, "IntInput")
- self.assertEqual(op.outputs, [])
- op_input = op.inputs[0].op
- self.assertEqual(op_input.type, "Enter")
- self.assertEqual(list(op_input.inputs), [x])
- self.assertEqual(op.graph, g)
- # pylint: disable=protected-access
- self.assertIsNotNone(op._get_control_flow_context())
- self.assertEqual(op._get_control_flow_context().name,
- "myloop/while_context")
- # pylint: enable=protected-access
-
- @test_util.run_v1_only("b/120545219")
- def testWhileLoopWithInternalControlDep(self):
- g = ops.Graph()
- with g.as_default():
- x = test_ops.int_output()
-
- def body(i):
- c = constant_op.constant(1.0, name="c")
- ops._create_c_op(ops.get_default_graph(),
- ops._NodeDef("IntInput", "myloop/myop"), [x], [])
- with ops.control_dependencies([c]):
- new_ops = g._add_new_tf_operations()
- self.assertEqual(len(new_ops), 1)
- return i
-
- control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
-
- op = g.get_operation_by_name("myloop/myop")
- self.assertIsNotNone(op)
- c = g.get_operation_by_name("myloop/c")
- self.assertIsNotNone(c)
- # Internal control dep is preserved
- self.assertEqual(op.control_inputs, [c])
-
- @test_util.run_v1_only("b/120545219")
- def testWhileLoopWithExternalControlDep(self):
- g = ops.Graph()
- with g.as_default():
- x = test_ops.int_output()
- c = constant_op.constant(1.0)
-
- def body(i):
- ops._create_c_op(ops.get_default_graph(),
- ops._NodeDef("IntInput", "myloop/myop"), [x], [])
- with ops.control_dependencies([c]):
- new_ops = g._add_new_tf_operations()
- self.assertEqual(len(new_ops), 1)
- return i
-
- control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
-
- op = g.get_operation_by_name("myloop/myop")
- self.assertIsNotNone(op)
- # External control dep is removed and replaced with internal control dep
- self.assertNotEqual(op.control_inputs[0], c.op)
- self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
-
-
- */
+ @test_util.run_v1_only("b/120545219")
+ def testWhileLoop(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+
+ def body(i):
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "myloop/myop"), [x], [])
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
+ return i
+
+ control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
+
+ op = g.get_operation_by_name("myloop/myop")
+ self.assertIsNotNone(op)
+ self.assertEqual(op.name, "myloop/myop")
+ self.assertEqual(op.type, "IntInput")
+ self.assertEqual(op.outputs, [])
+ op_input = op.inputs[0].op
+ self.assertEqual(op_input.type, "Enter")
+ self.assertEqual(list(op_input.inputs), [x])
+ self.assertEqual(op.graph, g)
+ # pylint: disable=protected-access
+ self.assertIsNotNone(op._get_control_flow_context())
+ self.assertEqual(op._get_control_flow_context().name,
+ "myloop/while_context")
+ # pylint: enable=protected-access
+
+ @test_util.run_v1_only("b/120545219")
+ def testWhileLoopWithInternalControlDep(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+
+ def body(i):
+ c = constant_op.constant(1.0, name="c")
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "myloop/myop"), [x], [])
+ with ops.control_dependencies([c]):
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
+ return i
+
+ control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
+
+ op = g.get_operation_by_name("myloop/myop")
+ self.assertIsNotNone(op)
+ c = g.get_operation_by_name("myloop/c")
+ self.assertIsNotNone(c)
+ # Internal control dep is preserved
+ self.assertEqual(op.control_inputs, [c])
+
+ @test_util.run_v1_only("b/120545219")
+ def testWhileLoopWithExternalControlDep(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+ c = constant_op.constant(1.0)
+
+ def body(i):
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "myloop/myop"), [x], [])
+ with ops.control_dependencies([c]):
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
+ return i
+
+ control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
+
+ op = g.get_operation_by_name("myloop/myop")
+ self.assertIsNotNone(op)
+ # External control dep is removed and replaced with internal control dep
+ self.assertNotEqual(op.control_inputs[0], c.op)
+ self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
+
+
+ */
}
}
diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs
index e52c21ff..7dbf5e23 100644
--- a/test/TensorFlowNET.UnitTest/PythonTest.cs
+++ b/test/TensorFlowNET.UnitTest/PythonTest.cs
@@ -29,6 +29,11 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(expected, given);
}
+ public void assertIsNotNone(object given)
+ {
+ Assert.IsNotNull(given);
+ }
+
protected PythonTest self { get => this; }
}
}