Browse Source

StopGradient: Adding basic gradient for stop_gradient function. (#357)

tags/v0.12
Sattisvar TANDABANY Haiping 6 years ago
parent
commit
235bad2575
2 changed files with 19 additions and 0 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  2. +13
    -0
      test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs

+ 6
- 0
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -196,6 +196,12 @@ namespace Tensorflow.Gradients
return new Tensor[] { _ReshapeToInput(op, grads[0]) }; return new Tensor[] { _ReshapeToInput(op, grads[0]) };
} }


[RegisterGradient("StopGradient")]
public static Tensor[] _NoGradient(Operation op, Tensor[] grads)
{
return new Tensor[] {null};
}

/// <summary> /// <summary>
/// Gradient for StridedSlice op. /// Gradient for StridedSlice op.
/// </summary> /// </summary>


+ 13
- 0
test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs View File

@@ -129,6 +129,19 @@ namespace TensorFlowNET.UnitTest.gradients_test
} }
} }


[TestMethod]
public void testStopGradientFunction()
{
var ap = tf.constant(1f);
var b = tf.tanh(ap) + gen_array_ops.stop_gradient(ap);
var g = tf.gradients(b, ap);
using (var sess = tf.Session())
{
var result = sess.run(g);
var actual = result[0].GetData<float>()[0];
self.assertEquals(0.41997434127f, actual);
}
}
[Ignore("TODO")] [Ignore("TODO")]
[TestMethod] [TestMethod]
public void testUnusedOutput() public void testUnusedOutput()


Loading…
Cancel
Save