diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs index 202558d3..4c316ad3 100644 --- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs @@ -122,13 +122,14 @@ namespace Tensorflow.Gradients [RegisterGradient("SquaredDifference")] public static Tensor[] _SquaredDifferenceGrad(Operation op, Tensor[] grads) { - //"""Returns the gradient for (x-y)^2.""" Tensor x = op.inputs[0]; Tensor y = op.inputs[1]; + var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype); + var x_grad = math_ops.scalar_mul(scale, grads[0]) * (x - y); return new Tensor[] { - x, - y + x_grad, + -x_grad }; } /// diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index ef7988fe..96bcc76f 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -272,6 +272,9 @@ namespace Tensorflow public static Tensor mul_no_nan(Tx x, Ty y, string name = null) => gen_math_ops.mul_no_nan(x, y, name: name); + public static Tensor scalar_mul(Tscale scale, Tx x, string name = null) + => tf.Context.ExecuteOp("Mul", name, new ExecuteOpArgs(scale, x)); + public static Tensor real(Tensor input, string name = null) { return tf_with(ops.name_scope(name, "Real", new[] { input }), scope => diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 60b22f71..00c55fae 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -48,7 +48,7 @@ namespace Tensorflow public tensorflow() { Logger = new LoggerConfiguration() - .MinimumLevel.Error() + .MinimumLevel.Debug() .WriteTo.Console() .CreateLogger(); diff --git a/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs index 3166da0f..9f07422d 100644 --- a/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs +++ b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs @@ -25,7 +25,6 @@ namespace TensorFlowNET.UnitTest.Gradient Assert.AreEqual((float)grad, 3.0f); } - [Ignore] [TestMethod] public void SquaredDifference_Constant() {