diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs
index 11dfcc89..9fe2aeed 100644
--- a/src/TensorFlowNET.Core/Gradients/array_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs
@@ -196,6 +196,12 @@ namespace Tensorflow.Gradients
return new Tensor[] { _ReshapeToInput(op, grads[0]) };
}
+ [RegisterGradient("StopGradient")]
+ public static Tensor[] _NoGradient(Operation op, Tensor[] grads)
+ {
+ return new Tensor[] {null};
+ }
+
///
/// Gradient for StridedSlice op.
///
diff --git a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs
index 68ee14e4..82306d9c 100644
--- a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs
+++ b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs
@@ -129,6 +129,19 @@ namespace TensorFlowNET.UnitTest.gradients_test
}
}
+ [TestMethod]
+ public void testStopGradientFunction()
+ {
+ var ap = tf.constant(1f);
+ var b = tf.tanh(ap) + gen_array_ops.stop_gradient(ap);
+ var g = tf.gradients(b, ap);
+ using (var sess = tf.Session())
+ {
+ var result = sess.run(g);
+ var actual = result[0].GetData()[0];
+ self.assertEquals(0.41997434127f, actual);
+ }
+ }
[Ignore("TODO")]
[TestMethod]
public void testUnusedOutput()