Browse Source

Graph: fixed bug with dependencies of nested with calls

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
cba10528bd
2 changed files with 15 additions and 36 deletions
  1. +3
    -3
      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  2. +12
    -33
      test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs

+ 3
- 3
src/TensorFlowNET.Core/Graphs/Graph.Control.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow
/// <returns>A list of control inputs for the op to be created.</returns> /// <returns>A list of control inputs for the op to be created.</returns>
private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops)
{ {
var ret = new ITensorOrOperation[0];
var ret = new List<ITensorOrOperation>();


foreach(var controller in _control_dependencies_stack) foreach(var controller in _control_dependencies_stack)
{ {
@@ -48,10 +48,10 @@ namespace Tensorflow
} }


if (!dominated) if (!dominated)
ret = controller.control_inputs.Where(x => !input_ops.Contains(x)).ToArray();
ret.AddRange( controller.control_inputs.Where(x => !input_ops.Contains(x)));
} }


return ret;
return ret.ToArray();
} }


/// <summary> /// <summary>


+ 12
- 33
test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs View File

@@ -65,8 +65,8 @@ namespace TensorFlowNET.UnitTest
} }
else else
{ {
var graph = tf.Graph();
with<Graph>(graph.as_default(), g =>
var graph = tf.Graph().as_default();
with<Graph>(graph, g =>
{ {
a = constant_op.constant(1.0); a = constant_op.constant(1.0);
b = future(); b = future();
@@ -122,7 +122,7 @@ def _apply_op(g, *args, **kwargs):
[TestMethod] [TestMethod]
public void TestBasicWithConversion() public void TestBasicWithConversion()
{ {
var g = ops.get_default_graph();
var g = tf.Graph().as_default();
// Note: _apply_op can be replaced by g.create_op // Note: _apply_op can be replaced by g.create_op
var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT }); var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
// TODO: ConvertibleObj, see original source below // TODO: ConvertibleObj, see original source below
@@ -142,20 +142,20 @@ def _apply_op(g, *args, **kwargs):
self.assertEqual(c.op.control_inputs, [a.op]) self.assertEqual(c.op.control_inputs, [a.op])
*/ */
} }
[Ignore("Fails with message: Op type not registered 'FloatOutput' in binary running on ...")]
[TestMethod]
//[Ignore]
[TestMethod()]
public void TestNested() public void TestNested()
{ {
var g = ops.get_default_graph(); var g = ops.get_default_graph();
var a_1 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
var a_2 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
var a_3 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
var a_4 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
var a_1 = constant_op.constant(1.0);
var a_2 = constant_op.constant(3.0);
var a_3 = constant_op.constant(4.0);
var a_4 = constant_op.constant(5.0);
Operation b_1 = null, b_2 = null; Operation b_1 = null, b_2 = null;
with(g.control_dependencies(new ITensorOrOperation[] { a_1, a_2, a_3, a_4 }), ctrl => with(g.control_dependencies(new ITensorOrOperation[] { a_1, a_2, a_3, a_4 }), ctrl =>
{ {
b_1 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
b_1 = constant_op.constant(6.0);
}); });
with(g.control_dependencies(new ITensorOrOperation[] { a_1 }), ctrl1 => with(g.control_dependencies(new ITensorOrOperation[] { a_1 }), ctrl1 =>
{ {
@@ -165,34 +165,13 @@ def _apply_op(g, *args, **kwargs):
{ {
with(g.control_dependencies(new ITensorOrOperation[] { a_4 }), ctrl4 => with(g.control_dependencies(new ITensorOrOperation[] { a_4 }), ctrl4 =>
{ {
b_2 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
b_2 = constant_op.constant(7.0);
}); });
}); });
}); });
}); });
AssertItemsEqual(new[] {a_1.op, a_2.op, a_3.op, a_4.op}, b_1.op.control_inputs); AssertItemsEqual(new[] {a_1.op, a_2.op, a_3.op, a_4.op}, b_1.op.control_inputs);
AssertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs); AssertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
/*
def testNested(self):
g = ops.Graph()
a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
with g.control_dependencies([a_1, a_2, a_3, a_4]):
b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
with g.control_dependencies([a_1]):
with g.control_dependencies([a_2]):
with g.control_dependencies([a_3]):
with g.control_dependencies([a_4]):
b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op],
b_1.op.control_inputs)
self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)
*/
} }


Loading…
Cancel
Save