From cba10528bdb0c1e4f260935da8ab817cb0fdd47f Mon Sep 17 00:00:00 2001 From: Meinrad Recheis Date: Mon, 8 Apr 2019 23:32:25 +0200 Subject: [PATCH] Graph: fixed bug with dependencies of nested with calls --- .../Graphs/Graph.Control.cs | 6 +-- .../ControlDependenciesTest.cs | 45 +++++-------------- 2 files changed, 15 insertions(+), 36 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs index d6fda591..bc1e15d5 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -30,7 +30,7 @@ namespace Tensorflow /// A list of control inputs for the op to be created. private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) { - var ret = new ITensorOrOperation[0]; + var ret = new List(); foreach(var controller in _control_dependencies_stack) { @@ -48,10 +48,10 @@ namespace Tensorflow } if (!dominated) - ret = controller.control_inputs.Where(x => !input_ops.Contains(x)).ToArray(); + ret.AddRange( controller.control_inputs.Where(x => !input_ops.Contains(x))); } - return ret; + return ret.ToArray(); } /// diff --git a/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs b/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs index 5146ae57..3187b37e 100644 --- a/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs +++ b/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs @@ -65,8 +65,8 @@ namespace TensorFlowNET.UnitTest } else { - var graph = tf.Graph(); - with(graph.as_default(), g => + var graph = tf.Graph().as_default(); + with(graph, g => { a = constant_op.constant(1.0); b = future(); @@ -122,7 +122,7 @@ def _apply_op(g, *args, **kwargs): [TestMethod] public void TestBasicWithConversion() { - var g = ops.get_default_graph(); + var g = tf.Graph().as_default(); // Note: _apply_op can be replaced by g.create_op var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT }); // TODO: ConvertibleObj, see original source below @@ -142,20 +142,20 @@ def _apply_op(g, *args, **kwargs): self.assertEqual(c.op.control_inputs, [a.op]) */ } - - [Ignore("Fails with message: Op type not registered 'FloatOutput' in binary running on ...")] - [TestMethod] + + //[Ignore] + [TestMethod()] public void TestNested() { var g = ops.get_default_graph(); - var a_1 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT }); - var a_2 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT }); - var a_3 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT }); - var a_4 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT }); + var a_1 = constant_op.constant(1.0); + var a_2 = constant_op.constant(3.0); + var a_3 = constant_op.constant(4.0); + var a_4 = constant_op.constant(5.0); Operation b_1 = null, b_2 = null; with(g.control_dependencies(new ITensorOrOperation[] { a_1, a_2, a_3, a_4 }), ctrl => { - b_1 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT }); + b_1 = constant_op.constant(6.0); }); with(g.control_dependencies(new ITensorOrOperation[] { a_1 }), ctrl1 => { @@ -165,34 +165,13 @@ def _apply_op(g, *args, **kwargs): { with(g.control_dependencies(new ITensorOrOperation[] { a_4 }), ctrl4 => { - b_2 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT }); + b_2 = constant_op.constant(7.0); }); }); }); }); AssertItemsEqual(new[] {a_1.op, a_2.op, a_3.op, a_4.op}, b_1.op.control_inputs); AssertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs); - /* -def testNested(self): -g = ops.Graph() -a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) -a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) -a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) -a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - -with g.control_dependencies([a_1, a_2, a_3, a_4]): - b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - -with g.control_dependencies([a_1]): - with g.control_dependencies([a_2]): - with g.control_dependencies([a_3]): - with g.control_dependencies([a_4]): - b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - -self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op], - b_1.op.control_inputs) -self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs) - */ }