diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index ee92b4ea..d9bc9b22 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -529,7 +529,12 @@ namespace Tensorflow.Gradients } else if (!input_0_shape.Contains(-1) && !tf.Context.executing_eagerly()) { - throw new NotImplementedException(""); + axes = axes.reshape(new Shape(-1)); + var shape_tensor = tf.constant(op.inputs[0].shape.as_int_list()); + var output_shape_kept_dims = math_ops.reduced_shape(shape_tensor, axes); + var tile_scaling = _safe_shape_div(shape_tensor, output_shape_kept_dims); + grad = array_ops.reshape(grad, output_shape_kept_dims); + return new Tensor[] { array_ops.tile(grad, tile_scaling), null }; } } } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 049d874e..263509f6 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -585,9 +585,14 @@ namespace Tensorflow } public static Tensor tile(Tensor input, Tensor multiples, string name = null) - { - throw new NotImplementedException("tile"); - } + => tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples) + { + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + Tmultiples = op.get_attr("Tmultiples") + } + }); public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) { diff --git a/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs index 12ad58e1..851a3bd7 100644 --- a/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs @@ -178,6 +178,19 @@ namespace TensorFlowNET.UnitTest.Gradient [TestMethod] public void testReduceSumGradients() { + /* python code + import tensorflow.compat.v1 as tf + tf.disable_v2_behavior() + + x = tf.placeholder(tf.float64, shape = (1, 1)) + m = tf.broadcast_to(x, (2, 3)) + g0 = tf.gradients(tf.reduce_sum(m), x)[0] + g1 = tf.gradients(tf.reduce_sum(m, axis = 0), x)[0] + g2 = tf.gradients(tf.reduce_sum(m, axis = 1), x)[0] + with tf.compat.v1.Session() as sess: + (r0, r1, r2) = sess.run((g0, g1, g2), {x: [[1.0]]}) + */ + var x = tf.placeholder(tf.float64, shape: new Shape(1, 1)); var m = tf.broadcast_to(x, new Shape(2, 3)); var g0 = tf.gradients(tf.reduce_sum(m), x)[0]; @@ -186,10 +199,10 @@ namespace TensorFlowNET.UnitTest.Gradient using (var session = tf.Session()) { - var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, 1.0)); + var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } })); self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)"); - self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)"); - self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)"); + self.assertFloat64Equal(6.0, r1[0], $"tf.reduce_sum(..., axis = 0)"); + self.assertFloat64Equal(6.0, r2[0], $"tf.reduce_sum(..., axis = 1)"); } }