From 1f55e683d4f0948bd5d161443c3037261b59a55f Mon Sep 17 00:00:00 2001 From: MPnoy Date: Fri, 23 Apr 2021 00:40:57 +0300 Subject: [PATCH] GradientSliceTest --- .../ManagedAPI/GradientTest.cs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs index 87140b00..7595822b 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs @@ -46,5 +46,19 @@ namespace TensorFlowNET.UnitTest.ManagedAPI var gr = gt.gradient(y, w); Assert.AreEqual(new float[] { 0, 0 }, gr.numpy()); } + + [TestMethod] + public void GradientSliceTest() + { + var X = tf.zeros(new Tensorflow.TensorShape(10)); + var W = tf.Variable(-0.06f, name: "weight"); + var b = tf.Variable(-0.73f, name: "bias"); + using var g = tf.GradientTape(); + var pred = W * X + b; + var test = tf.slice(pred, new[] { 0 }, pred.shape); + var gradients = g.gradient(test, (W, b)); + Assert.AreNotEqual(gradients.Item1, null); + Assert.AreNotEqual(gradients.Item2, null); + } } }