From 91399b15279cda0bcb7ba20678c84ec5bf4c2e6d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 14 Nov 2020 14:31:56 -0600 Subject: [PATCH] Fixed add GradientOperatorMulTest #642 --- .../Eager/EagerRunner.RecordGradient.cs | 5 +- src/TensorFlowNET.Core/Gradients/math_grad.cs | 51 +++++++++---------- .../Tensorflow.Keras.csproj | 2 +- .../ManagedAPI/GradientTest.cs | 3 +- 4 files changed, 30 insertions(+), 31 deletions(-) diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs index e89c710a..917e3d1c 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs @@ -1,4 +1,5 @@ -using System.Linq; +using System; +using System.Linq; using Tensorflow.Gradients; using static Tensorflow.Binding; using static Tensorflow.tensorflow; @@ -37,7 +38,7 @@ namespace Tensorflow.Eager }*/ } - // Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}"); + Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}"); if (!should_record) return should_record; Tensor[] op_outputs; diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 3a6e9754..51956746 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -212,8 +212,9 @@ namespace Tensorflow.Gradients }; } - var (sx, sy) = SmartBroadcastGradientArgs(x, y); - var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); + var broads = SmartBroadcastGradientArgs(x, y); + var (sx, rx, must_reduce_x) = broads[0]; + var (sy, ry, must_reduce_y) = broads[1]; x = math_ops.conj(x); y = math_ops.conj(y); @@ -222,33 +223,21 @@ namespace Tensorflow.Gradients if (op is EagerOperation op_eager1 && op_eager1.SkipInputIndices.Contains(0)) - { - return new Tensor[] - { - gen_math_ops.mul(grad, math_ops.conj(y)), - null - }; - } - // else if not must_reduce_x: - // gx = gen_math_ops.mul(grad, y) + gy = null; + else if (!must_reduce_x) + gx = gen_math_ops.mul(grad, y); else - { gx = array_ops.reshape( math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx); - } if (op is EagerOperation op_eager2 && op_eager2.SkipInputIndices.Contains(1)) - { - - } - // else if not must_reduce_y: - // gy = gen_math_ops.mul(x, grad) + gy = null; + else if (!must_reduce_y) + gy = gen_math_ops.mul(x, grad); else - { gy = array_ops.reshape( math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy); - } return new Tensor[] { gx, gy }; } @@ -479,8 +468,9 @@ namespace Tensorflow.Gradients _ShapesFullySpecifiedAndEqual(x, y, grad)) return new Tensor[] { grad, -grad }; - var (sx, sy) = SmartBroadcastGradientArgs(x, y); - var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); + var broads = SmartBroadcastGradientArgs(x, y); + var (sx, rx, must_reduce_x) = broads[0]; + var (sy, ry, must_reduce_y) = broads[1]; var gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx); var gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy); @@ -728,8 +718,10 @@ namespace Tensorflow.Gradients var z = op.outputs[0]; - var (sx, sy) = SmartBroadcastGradientArgs(x, y); - var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); + var broads = SmartBroadcastGradientArgs(x, y); + var (sx, rx, must_reduce_x) = broads[0]; + var (sy, ry, must_reduce_y) = broads[1]; + x = math_ops.conj(x); y = math_ops.conj(y); z = math_ops.conj(z); @@ -761,7 +753,7 @@ namespace Tensorflow.Gradients /// /// /// - private static (Tensor, Tensor) SmartBroadcastGradientArgs(Tensor x, Tensor y) + private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y) { Tensor sx, sy; if (x.TensorShape.is_fully_defined() && @@ -769,6 +761,13 @@ namespace Tensorflow.Gradients { sx = array_ops.shape(x); sy = array_ops.shape(y); + + var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); + return new[] + { + (sx, rx, true), + (sy, ry, true) + }; } else { @@ -776,7 +775,7 @@ namespace Tensorflow.Gradients sy = array_ops.shape_internal(y, optimize: false); } - return (sx, sy); + throw new NotImplementedException(""); } } } diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index 9d459e15..6b7819f9 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -12,7 +12,7 @@ Apache 2.0, Haiping Chen 2020 TensorFlow.Keras https://github.com/SciSharp/TensorFlow.NET - https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 + https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 https://github.com/SciSharp/TensorFlow.NET Keras for .NET is a C# version of Keras ported from the python version. Keras for .NET diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs index 0cf0d2f5..87140b00 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs @@ -44,8 +44,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI using var gt = tf.GradientTape(); var y = x * w; var gr = gt.gradient(y, w); - Assert.AreNotEqual(null, gr); + Assert.AreEqual(new float[] { 0, 0 }, gr.numpy()); } - } }