From 0e0ff5860bb8256fd7e008a0022d039afaf16e33 Mon Sep 17 00:00:00 2001 From: MPnoy Date: Fri, 23 Apr 2021 00:40:57 +0300 Subject: [PATCH] GradientSliceTest --- .../ManagedAPI/GradientTest.cs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs index 87140b00..7595822b 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs @@ -46,5 +46,19 @@ namespace TensorFlowNET.UnitTest.ManagedAPI var gr = gt.gradient(y, w); Assert.AreEqual(new float[] { 0, 0 }, gr.numpy()); } + + [TestMethod] + public void GradientSliceTest() + { + var X = tf.zeros(new Tensorflow.TensorShape(10)); + var W = tf.Variable(-0.06f, name: "weight"); + var b = tf.Variable(-0.73f, name: "bias"); + using var g = tf.GradientTape(); + var pred = W * X + b; + var test = tf.slice(pred, new[] { 0 }, pred.shape); + var gradients = g.gradient(test, (W, b)); + Assert.AreNotEqual(gradients.Item1, null); + Assert.AreNotEqual(gradients.Item2, null); + } } }