diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index 11dfcc89..9fe2aeed 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -196,6 +196,12 @@ namespace Tensorflow.Gradients return new Tensor[] { _ReshapeToInput(op, grads[0]) }; } + [RegisterGradient("StopGradient")] + public static Tensor[] _NoGradient(Operation op, Tensor[] grads) + { + return new Tensor[] {null}; + } + /// /// Gradient for StridedSlice op. /// diff --git a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs index 68ee14e4..82306d9c 100644 --- a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs +++ b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs @@ -129,6 +129,19 @@ namespace TensorFlowNET.UnitTest.gradients_test } } + [TestMethod] + public void testStopGradientFunction() + { + var ap = tf.constant(1f); + var b = tf.tanh(ap) + gen_array_ops.stop_gradient(ap); + var g = tf.gradients(b, ap); + using (var sess = tf.Session()) + { + var result = sess.run(g); + var actual = result[0].GetData()[0]; + self.assertEquals(0.41997434127f, actual); + } + } [Ignore("TODO")] [TestMethod] public void testUnusedOutput()