| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.ManagedAPI | namespace TensorFlowNET.UnitTest.ManagedAPI | ||||
| @@ -50,7 +51,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
| [TestMethod] | [TestMethod] | ||||
| public void GradientSliceTest() | public void GradientSliceTest() | ||||
| { | { | ||||
| var X = tf.zeros(new Tensorflow.TensorShape(10)); | |||||
| var X = tf.zeros(new TensorShape(10)); | |||||
| var W = tf.Variable(-0.06f, name: "weight"); | var W = tf.Variable(-0.06f, name: "weight"); | ||||
| var b = tf.Variable(-0.73f, name: "bias"); | var b = tf.Variable(-0.73f, name: "bias"); | ||||
| using var g = tf.GradientTape(); | using var g = tf.GradientTape(); | ||||
| @@ -60,5 +61,19 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
| Assert.AreNotEqual(gradients.Item1, null); | Assert.AreNotEqual(gradients.Item1, null); | ||||
| Assert.AreNotEqual(gradients.Item2, null); | Assert.AreNotEqual(gradients.Item2, null); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void GradientConcatTest() | |||||
| { | |||||
| var X = tf.zeros(new TensorShape(10)); | |||||
| var W = tf.Variable(-0.06f, name: "weight"); | |||||
| var b = tf.Variable(-0.73f, name: "bias"); | |||||
| var test = tf.concat(new Tensor[] { W, b }, 0); | |||||
| using var g = tf.GradientTape(); | |||||
| var pred = test[0] * X + test[1]; | |||||
| var gradients = g.gradient(pred, (W, b)); | |||||
| Assert.AreEqual((float)gradients.Item1, 0); | |||||
| Assert.AreEqual((float)gradients.Item2, 10); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||