| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using NumSharp; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -65,15 +66,16 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
| [TestMethod] | [TestMethod] | ||||
| public void GradientConcatTest() | public void GradientConcatTest() | ||||
| { | { | ||||
| var X = tf.zeros(10); | |||||
| var W = tf.Variable(-0.06f, name: "weight"); | |||||
| var b = tf.Variable(-0.73f, name: "bias"); | |||||
| var test = tf.concat(new Tensor[] { W, b }, 0); | |||||
| var w1 = tf.Variable(new[] { new[] { 1f } }); | |||||
| var w2 = tf.Variable(new[] { new[] { 3f } }); | |||||
| using var g = tf.GradientTape(); | using var g = tf.GradientTape(); | ||||
| var pred = test[0] * X + test[1]; | |||||
| var gradients = g.gradient(pred, (W, b)); | |||||
| Assert.IsNull(gradients.Item1); | |||||
| Assert.IsNull(gradients.Item2); | |||||
| var w = tf.concat(new Tensor[] { w1, w2 }, 0); | |||||
| var x = tf.ones((1, 2)); | |||||
| var y = tf.reduce_sum(x, 1); | |||||
| var r = tf.matmul(w, x); | |||||
| var gradients = g.gradient(r, w); | |||||
| Assert.AreEqual((float)gradients[0][0], 2f); | |||||
| Assert.AreEqual((float)gradients[1][0], 2f); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||