From 99bd08b1da3f8ea34d15e46c8329d9c967625009 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Tue, 23 Jul 2019 20:03:16 -0500 Subject: [PATCH] fix object reference issue for _AggregatedGrads #303 --- src/TensorFlowHub/TensorFlowHub.csproj | 4 +-- .../Gradients/gradients_util.cs | 21 ++++++------ .../ImageProcessing/DigitRecognitionRNN.cs | 2 +- test/TensorFlowNET.UnitTest/GradientTest.cs | 33 +++++++++++++++++++ 4 files changed, 46 insertions(+), 14 deletions(-) diff --git a/src/TensorFlowHub/TensorFlowHub.csproj b/src/TensorFlowHub/TensorFlowHub.csproj index ffd4e11f..ddf8ca25 100644 --- a/src/TensorFlowHub/TensorFlowHub.csproj +++ b/src/TensorFlowHub/TensorFlowHub.csproj @@ -1,4 +1,4 @@ - + TensorFlow.Net.Hub Tensorflow.Hub @@ -8,7 +8,7 @@ - + diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 144179d9..3e556da0 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -137,7 +137,7 @@ namespace Tensorflow if (loop_state != null) ; else - out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i); + out_grads[i] = new List { control_flow_ops.ZerosLikeOutsideLoop(op, i) }; } } @@ -146,7 +146,7 @@ namespace Tensorflow string name1 = scope1; if (grad_fn != null) { - in_grads = _MaybeCompile(grad_scope, op, out_grads, null, grad_fn); + in_grads = _MaybeCompile(grad_scope, op, out_grads[0].ToArray(), null, grad_fn); _VerifyGeneratedGradients(in_grads, op); } @@ -310,10 +310,9 @@ namespace Tensorflow yield return op.inputs[i]; } - private static Tensor[] _AggregatedGrads(Dictionary>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0) + private static List> _AggregatedGrads(Dictionary>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0) { var out_grads = _GetGrads(grads, op); - var return_grads = new Tensor[out_grads.Count]; foreach (var (i, out_grad) in enumerate(out_grads)) { @@ -334,21 +333,21 @@ namespace Tensorflow throw new ValueError("_AggregatedGrads out_grad.Length == 0"); } - return_grads[i] = out_grad[0]; + out_grads[i] = out_grad; } else { used = "add_n"; - return_grads[i] = _MultiDeviceAddN(out_grad.ToArray(), gradient_uid); + out_grads[i] = new List { _MultiDeviceAddN(out_grad.ToArray(), gradient_uid) }; } } else { - return_grads[i] = null; + out_grads[i] = null; } } - return return_grads; + return out_grads; } /// @@ -362,7 +361,7 @@ namespace Tensorflow // Basic function structure comes from control_flow_ops.group(). // Sort tensors according to their devices. var tensors_on_device = new Dictionary>(); - + foreach (var tensor in tensor_list) { if (!tensors_on_device.ContainsKey(tensor.Device)) @@ -370,10 +369,10 @@ namespace Tensorflow tensors_on_device[tensor.Device].Add(tensor); } - + // For each device, add the tensors on that device first. var summands = new List(); - foreach(var dev in tensors_on_device.Keys) + foreach (var dev in tensors_on_device.Keys) { var tensors = tensors_on_device[dev]; ops._colocate_with_for_gradient(tensors[0].op, gradient_uid, ignore_existing: true); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs index 796d6103..d51ca9ad 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs @@ -28,7 +28,7 @@ namespace TensorFlowNET.Examples.ImageProcess /// public class DigitRecognitionRNN : IExample { - public bool Enabled { get; set; } = true; + public bool Enabled { get; set; } = false; public bool IsImportingGraph { get; set; } = false; public string Name => "MNIST RNN"; diff --git a/test/TensorFlowNET.UnitTest/GradientTest.cs b/test/TensorFlowNET.UnitTest/GradientTest.cs index 6a1f77a2..38d728fd 100644 --- a/test/TensorFlowNET.UnitTest/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/GradientTest.cs @@ -2,6 +2,7 @@ using NumSharp; using System.Linq; using Tensorflow; +using static Tensorflow.Python; namespace TensorFlowNET.UnitTest { @@ -28,6 +29,38 @@ namespace TensorFlowNET.UnitTest Assert.AreEqual(g[1].name, "gradients/Fill:0"); } + [TestMethod] + public void Gradient2x() + { + var graph = tf.Graph().as_default(); + with(tf.Session(graph), sess => { + var x = tf.constant(7.0f); + var y = x * x * tf.constant(0.1f); + + var grad = tf.gradients(y, x); + Assert.AreEqual(grad[0].name, "gradients/AddN:0"); + + float r = sess.run(grad[0]); + Assert.AreEqual(r, 1.4f); + }); + } + + [TestMethod] + public void Gradient3x() + { + var graph = tf.Graph().as_default(); + with(tf.Session(graph), sess => { + var x = tf.constant(7.0f); + var y = x * x * x * tf.constant(0.1f); + + var grad = tf.gradients(y, x); + Assert.AreEqual(grad[0].name, "gradients/AddN:0"); + + float r = sess.run(grad[0]); + Assert.AreEqual(r, 14.700001f); + }); + } + [TestMethod] public void StridedSlice() {