From 10fda1f2a41f1d8becda26ac8af08c481f820dff Mon Sep 17 00:00:00 2001 From: Alexander Mishunin Date: Mon, 14 Mar 2022 10:23:32 +0300 Subject: [PATCH] Fix reduce_sum test case --- .../GradientTest/GradientTest.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs index 851a3bd7..f60fe6d9 100644 --- a/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs @@ -185,8 +185,8 @@ namespace TensorFlowNET.UnitTest.Gradient x = tf.placeholder(tf.float64, shape = (1, 1)) m = tf.broadcast_to(x, (2, 3)) g0 = tf.gradients(tf.reduce_sum(m), x)[0] - g1 = tf.gradients(tf.reduce_sum(m, axis = 0), x)[0] - g2 = tf.gradients(tf.reduce_sum(m, axis = 1), x)[0] + g1 = tf.gradients(tf.reduce_sum(m, axis = 0)[0], x)[0] + g2 = tf.gradients(tf.reduce_sum(m, axis = 1)[0], x)[0] with tf.compat.v1.Session() as sess: (r0, r1, r2) = sess.run((g0, g1, g2), {x: [[1.0]]}) */ @@ -194,15 +194,15 @@ namespace TensorFlowNET.UnitTest.Gradient var x = tf.placeholder(tf.float64, shape: new Shape(1, 1)); var m = tf.broadcast_to(x, new Shape(2, 3)); var g0 = tf.gradients(tf.reduce_sum(m), x)[0]; - var g1 = tf.gradients(tf.reduce_sum(m, axis: 0), x)[0]; - var g2 = tf.gradients(tf.reduce_sum(m, axis: 1), x)[0]; + var g1 = tf.gradients(tf.reduce_sum(m, axis: 0)[0], x)[0]; + var g2 = tf.gradients(tf.reduce_sum(m, axis: 1)[0], x)[0]; using (var session = tf.Session()) { var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } })); self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)"); - self.assertFloat64Equal(6.0, r1[0], $"tf.reduce_sum(..., axis = 0)"); - self.assertFloat64Equal(6.0, r2[0], $"tf.reduce_sum(..., axis = 1)"); + self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)"); + self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)"); } }