| @@ -46,5 +46,19 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
| var gr = gt.gradient(y, w); | var gr = gt.gradient(y, w); | ||||
| Assert.AreEqual(new float[] { 0, 0 }, gr.numpy()); | Assert.AreEqual(new float[] { 0, 0 }, gr.numpy()); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void GradientSliceTest() | |||||
| { | |||||
| var X = tf.zeros(new Tensorflow.TensorShape(10)); | |||||
| var W = tf.Variable(-0.06f, name: "weight"); | |||||
| var b = tf.Variable(-0.73f, name: "bias"); | |||||
| using var g = tf.GradientTape(); | |||||
| var pred = W * X + b; | |||||
| var test = tf.slice(pred, new[] { 0 }, pred.shape); | |||||
| var gradients = g.gradient(test, (W, b)); | |||||
| Assert.AreNotEqual(gradients.Item1, null); | |||||
| Assert.AreNotEqual(gradients.Item2, null); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||