Browse Source

GradientSliceTest

pull/801/head
MPnoy 4 years ago
parent
commit
1f55e683d4
1 changed files with 14 additions and 0 deletions
  1. +14
    -0
      test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs

+ 14
- 0
test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs View File

@@ -46,5 +46,19 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
var gr = gt.gradient(y, w); var gr = gt.gradient(y, w);
Assert.AreEqual(new float[] { 0, 0 }, gr.numpy()); Assert.AreEqual(new float[] { 0, 0 }, gr.numpy());
} }

[TestMethod]
public void GradientSliceTest()
{
var X = tf.zeros(new Tensorflow.TensorShape(10));
var W = tf.Variable(-0.06f, name: "weight");
var b = tf.Variable(-0.73f, name: "bias");
using var g = tf.GradientTape();
var pred = W * X + b;
var test = tf.slice(pred, new[] { 0 }, pred.shape);
var gradients = g.gradient(test, (W, b));
Assert.AreNotEqual(gradients.Item1, null);
Assert.AreNotEqual(gradients.Item2, null);
}
} }
} }

Loading…
Cancel
Save