From cd611f1750987b64bc56257ebc1353c3e47044a1 Mon Sep 17 00:00:00 2001 From: MPnoy Date: Fri, 23 Apr 2021 01:13:16 +0300 Subject: [PATCH] GradientConcatTest --- .../ManagedAPI/GradientTest.cs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs index 7595822b..28210cfc 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.ManagedAPI @@ -50,7 +51,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI [TestMethod] public void GradientSliceTest() { - var X = tf.zeros(new Tensorflow.TensorShape(10)); + var X = tf.zeros(new TensorShape(10)); var W = tf.Variable(-0.06f, name: "weight"); var b = tf.Variable(-0.73f, name: "bias"); using var g = tf.GradientTape(); @@ -60,5 +61,19 @@ namespace TensorFlowNET.UnitTest.ManagedAPI Assert.AreNotEqual(gradients.Item1, null); Assert.AreNotEqual(gradients.Item2, null); } + + [TestMethod] + public void GradientConcatTest() + { + var X = tf.zeros(new TensorShape(10)); + var W = tf.Variable(-0.06f, name: "weight"); + var b = tf.Variable(-0.73f, name: "bias"); + var test = tf.concat(new Tensor[] { W, b }, 0); + using var g = tf.GradientTape(); + var pred = test[0] * X + test[1]; + var gradients = g.gradient(pred, (W, b)); + Assert.AreEqual((float)gradients.Item1, 0); + Assert.AreEqual((float)gradients.Item2, 10); + } } }