diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index de030b70..e4bccea1 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -117,6 +117,13 @@ namespace Tensorflow case Operation c1: control_input_ops.Add(c1); break; + case Tensor tensor: + control_input_ops.Add(tensor.op); + break; + // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented + //case IndexedSlices islices: + // control_input_ops.Add(islices.op); + // break; default: throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}"); } diff --git a/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs b/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs new file mode 100644 index 00000000..c2c40337 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs @@ -0,0 +1,38 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; + +namespace TensorFlowNET.UnitTest +{ + /// + /// tensorflow/python/framework/ops_test.py + /// + [TestClass] + public class ControlDependenciesTest : Python + { + [TestMethod] + public void TestBasic() + { + var graph = tf.Graph().as_default(); + Tensor a=null, b = null, c = null, d = null, e = null; + with(graph, g => + { + a = constant_op.constant(1.0); + b = constant_op.constant(1.0); + with(g.control_dependencies(new ITensorOrOperation[] {a}), x => + { + c = constant_op.constant(1.0); + d = array_ops.identity(b); + e = array_ops.identity(c); + }); + }); + Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] {a.op})); + Assert.IsTrue(Enumerable.SequenceEqual(d.op.control_inputs, new[] {a.op})); + // e should be dominated by c. + Assert.AreEqual(0, e.op.control_inputs.Length); + } + } +}