From 615b54ddc6a521c22649f8ba692dc76195d896a9 Mon Sep 17 00:00:00 2001 From: Meinrad Recheis Date: Wed, 10 Apr 2019 10:51:13 +0200 Subject: [PATCH] more control flow fixes but CreateOpFromTfOperationTest.TestCond still fails --- .../Operations/ControlFlows/CondContext.cs | 9 ++- .../Operations/Operation.Input.cs | 6 +- .../Operations/control_flow_ops.py.cs | 55 ++++++++++++-- .../CreateOpFromTfOperationTest.cs | 73 +++++++++++++++---- test/TensorFlowNET.UnitTest/PythonTest.cs | 2 +- 5 files changed, 118 insertions(+), 27 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index a0d84e89..f32a6fa7 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow.Operations @@ -92,13 +93,15 @@ namespace Tensorflow.Operations switch (original_result) { + case Tensor result: + return (original_result, _BuildCondTensor(new[] { result.op })); case Operation[] results: return (original_result, _BuildCondTensor(results)); - case Tensor tensor: - return (original_result, tensor); case float[] fv: + { var result = ops.convert_to_tensor(fv[0]); return (original_result, result ); + } default: return (original_result, null); } @@ -114,7 +117,7 @@ namespace Tensorflow.Operations switch (original_result) { case Tensor[] results: - return (original_result, results); + return (original_result, new Tensor[] { _BuildCondTensor(results.Select(t=>t.op).ToArray())}); case Operation[] results: return (original_result, new Tensor[] { _BuildCondTensor (results) }); case float[] fv: diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs index 9ef89271..26c9c08c 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -27,9 +27,9 @@ namespace Tensorflow for (int i = 0; i < NumInputs; i++) { - var tf_outpus = Input(i); - var op = new Operation(tf_outpus.oper); - retval[i] = op.outputs[tf_outpus.index]; + var tf_outputs = Input(i); + var op = new Operation(tf_outputs.oper); + retval[i] = op.outputs[tf_outputs.index]; } _inputs = new InputList(retval); diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 40b9a461..e0f38c95 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -142,10 +142,29 @@ namespace Tensorflow return tpl.ToArray(); }); - } - + } + + /// + /// Produces the content of `output_tensor` only after `dependencies`. + /// + /// In some cases, a user may want the output of an operation to be + /// consumed externally only after some other dependencies have run + /// first.This function ensures returns `output_tensor`, but only after all + /// operations in `dependencies` have run.Note that this means that there is + /// no guarantee that `output_tensor` will be evaluated after any `dependencies` + /// have run. + /// + /// See also `tf.tuple` and `tf.group`. + /// + /// Iterable of operations to run before this op finishes. + /// A `Tensor` or `IndexedSlices` that will be returned. + /// (Optional) A name for this operation. + /// Same as `output_tensor`. public static Tensor with_dependencies(Operation[] dependencies, Tensor output_tensor, string name = null) { + //TODO: missing original code + //if context.executing_eagerly(): + // return output_tensor var values = new List(); values.AddRange(dependencies); values.Add(output_tensor); @@ -153,12 +172,15 @@ namespace Tensorflow return with(ops.name_scope(name, "control_dependency", values), scope => { name = scope; - - return with(ops.control_dependencies(dependencies), ctl => + // TODO: missing original code + //with ops.colocate_with(output_tensor): { - output_tensor = ops.convert_to_tensor_or_composite(output_tensor); - return _Identity(output_tensor, name: name); - }); + return with(ops.control_dependencies(dependencies), ctl => + { + output_tensor = ops.convert_to_tensor_or_composite(output_tensor); + return _Identity(output_tensor, name: name); + }); + } }); } @@ -393,8 +415,27 @@ namespace Tensorflow return tensors_or_flows; } + /// + /// Returns the value of an available element of `inputs`. + /// + /// This op tests each of the tensors in `inputs` in turn to determine if any of + /// them is available.If it finds an available tensor, it returns it and its + /// index in `inputs`. + /// + /// It is an error if more than one tensor in `inputs` is available.If no tensor + /// in `inputs` is available, the returned tensor and index are not set. + /// + /// This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of + /// `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices + /// before merging. + /// + /// inputs: The input tensors, at most one of which is available. + /// A name for this operation (optional). + /// public static Tensor merge(Tensor[] inputs, string name = null) { + if (inputs.Any(x => x == null)) + throw new ValueError($"At least one of the merge inputs is null: {inputs}"); return with(ops.name_scope(name, "Merge", inputs), scope => { name = scope; diff --git a/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs index 78c9cd45..967dbfaa 100644 --- a/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs +++ b/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs @@ -1,8 +1,10 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow; +using Tensorflow.Operations; namespace TensorFlowNET.UnitTest { @@ -19,21 +21,21 @@ namespace TensorFlowNET.UnitTest [TestClass] public class CreateOpFromTfOperationTest : PythonTest { - + [TestMethod] public void TestShape() { var graph = tf.Graph().as_default(); with(graph, g => { - var x = constant_op.constant(new [,] { {1, 2, 3}, {4, 5, 6}}); - var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); + var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }); + var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]); var op = g._create_op_from_tf_operation(c_op); Assert.AreEqual("myop", op.name); Assert.AreEqual("Identity", op.type); Assert.AreEqual(1, len(op.outputs)); - assertItemsEqual(new []{2, 3}, op.outputs[0].shape); + assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape); }); } @@ -47,7 +49,7 @@ namespace TensorFlowNET.UnitTest //var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]); //var op = g._create_op_from_tf_operation(c_op); //var op2 = g._create_op_from_tf_operation(c_op2); - var op = constant_op.constant(0, name:"myop").op; + var op = constant_op.constant(0, name: "myop").op; var op2 = constant_op.constant(0, name: "myop_1").op; // Create ops with same names as op1 and op2. We expect the new names to be @@ -62,7 +64,7 @@ namespace TensorFlowNET.UnitTest }); } - [Ignore("Something is not right, Switch gets not inserted correctly?")] + [Ignore("Switch op gets not inserted correctly in the graph")] [TestMethod] public void TestCond() { @@ -91,8 +93,7 @@ namespace TensorFlowNET.UnitTest self.assertEqual(op_input.inputs[0], x); self.assertEqual(op.graph, g); self.assertIsNotNone(op._get_control_flow_context()); - // TODO: op._get_control_flow_context().name not implemented - //self.assertEqual(op._get_control_flow_context().name, "cond/cond_text"); + self.assertEqual((op._get_control_flow_context() as ControlFlowContext).name, "cond/cond_text"); }); /* @test_util.run_v1_only("b/120545219") @@ -126,7 +127,39 @@ namespace TensorFlowNET.UnitTest # pylint: enable=protected-access */ } - /* + + [Ignore("Todo: Port")] + [TestMethod] + public void TestWhileLoop() + { + var graph = tf.Graph().as_default(); + Operation x=null; + with(graph, g => + { + x = constant_op.constant(42); + var body = new Func(i => + { + ops._create_c_op(ops.get_default_graph(), ops._NodeDef("Identity", "myloop/myop"), new[] {x}, + new Operation[0]); + var new_ops = g._add_new_tf_operations(); + self.assertEqual(len(new_ops), 1); + return i; + }); + // TODO: port control_flow_ops.while_loop + //control_flow_ops.while_loop( i => i < 10, body, new int[]{0}, name = "myloop"); + }); + var op = graph.get_operation_by_name("myloop/myop"); + self.assertIsNotNone(op); + self.assertEqual(op.name, "myloop/myop"); + self.assertEqual(op.type, "Identity"); + self.assertEqual(op.outputs.Length, 0); + var op_input = op.inputs[0].op; + self.assertEqual(op_input.type, "Enter"); + self.assertItemsEqual(op_input.inputs.OfType().ToArray(), new[] {x}); + self.assertEqual(op.graph, graph); + self.assertIsNotNone(op._get_control_flow_context()); + self.assertEqual(((ControlFlowContext)op._get_control_flow_context()).name, "myloop/while_context"); + /* @test_util.run_v1_only("b/120545219") def testWhileLoop(self): g = ops.Graph() @@ -156,8 +189,15 @@ namespace TensorFlowNET.UnitTest self.assertEqual(op._get_control_flow_context().name, "myloop/while_context") # pylint: enable=protected-access + */ + } - @test_util.run_v1_only("b/120545219") + [Ignore("Todo: Port")] + [TestMethod] + public void TestWhileLoopWithInternalControlDep() + { + /* +@test_util.run_v1_only("b/120545219") def testWhileLoopWithInternalControlDep(self): g = ops.Graph() with g.as_default(): @@ -180,7 +220,14 @@ namespace TensorFlowNET.UnitTest self.assertIsNotNone(c) # Internal control dep is preserved self.assertEqual(op.control_inputs, [c]) + */ + } + [Ignore("Todo: Port")] + [TestMethod] + public void TestWhileLoopWithExternalControlDep() + { + /* @test_util.run_v1_only("b/120545219") def testWhileLoopWithExternalControlDep(self): g = ops.Graph() @@ -203,8 +250,8 @@ namespace TensorFlowNET.UnitTest # External control dep is removed and replaced with internal control dep self.assertNotEqual(op.control_inputs[0], c.op) self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) + */ + } - - */ - } } +} diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index 7dbf5e23..c997c9bf 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest /// public class PythonTest : Python { - public void assertItemsEqual(ICollection expected, ICollection given) + public void assertItemsEqual(ICollection given, ICollection expected) { Assert.IsNotNull(expected); Assert.IsNotNull(given);