diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index de030b70..e4bccea1 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -117,6 +117,13 @@ namespace Tensorflow
case Operation c1:
control_input_ops.Add(c1);
break;
+ case Tensor tensor:
+ control_input_ops.Add(tensor.op);
+ break;
+ // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented
+ //case IndexedSlices islices:
+ // control_input_ops.Add(islices.op);
+ // break;
default:
throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
}
diff --git a/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs b/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs
new file mode 100644
index 00000000..c2c40337
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs
@@ -0,0 +1,38 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow;
+
+namespace TensorFlowNET.UnitTest
+{
+ ///
+ /// tensorflow/python/framework/ops_test.py
+ ///
+ [TestClass]
+ public class ControlDependenciesTest : Python
+ {
+ [TestMethod]
+ public void TestBasic()
+ {
+ var graph = tf.Graph().as_default();
+ Tensor a=null, b = null, c = null, d = null, e = null;
+ with(graph, g =>
+ {
+ a = constant_op.constant(1.0);
+ b = constant_op.constant(1.0);
+ with(g.control_dependencies(new ITensorOrOperation[] {a}), x =>
+ {
+ c = constant_op.constant(1.0);
+ d = array_ops.identity(b);
+ e = array_ops.identity(c);
+ });
+ });
+ Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] {a.op}));
+ Assert.IsTrue(Enumerable.SequenceEqual(d.op.control_inputs, new[] {a.op}));
+ // e should be dominated by c.
+ Assert.AreEqual(0, e.op.control_inputs.Length);
+ }
+ }
+}