diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
index e89c710a..917e3d1c 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
@@ -1,4 +1,5 @@
-using System.Linq;
+using System;
+using System.Linq;
using Tensorflow.Gradients;
using static Tensorflow.Binding;
using static Tensorflow.tensorflow;
@@ -37,7 +38,7 @@ namespace Tensorflow.Eager
}*/
}
// Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}");
if (!should_record) return should_record;
Tensor[] op_outputs;
diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs
index 3a6e9754..51956746 100644
--- a/src/TensorFlowNET.Core/Gradients/math_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs
@@ -212,8 +212,9 @@ namespace Tensorflow.Gradients
};
}
- var (sx, sy) = SmartBroadcastGradientArgs(x, y);
- var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+ var broads = SmartBroadcastGradientArgs(x, y);
+ var (sx, rx, must_reduce_x) = broads[0];
+ var (sy, ry, must_reduce_y) = broads[1];
x = math_ops.conj(x);
y = math_ops.conj(y);
@@ -222,33 +223,21 @@ namespace Tensorflow.Gradients
if (op is EagerOperation op_eager1 &&
op_eager1.SkipInputIndices.Contains(0))
- {
- return new Tensor[]
- {
- gen_math_ops.mul(grad, math_ops.conj(y)),
- null
- };
- }
- // else if not must_reduce_x:
- // gx = gen_math_ops.mul(grad, y)
+ gx = null;
+ else if (!must_reduce_x)
+ gx = gen_math_ops.mul(grad, y);
else
- {
gx = array_ops.reshape(
math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx);
- }
if (op is EagerOperation op_eager2 &&
op_eager2.SkipInputIndices.Contains(1))
- {
-
- }
- // else if not must_reduce_y:
- // gy = gen_math_ops.mul(x, grad)
+ gy = null;
+ else if (!must_reduce_y)
+ gy = gen_math_ops.mul(x, grad);
else
- {
gy = array_ops.reshape(
math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy);
- }
return new Tensor[] { gx, gy };
}
@@ -479,8 +468,9 @@ namespace Tensorflow.Gradients
_ShapesFullySpecifiedAndEqual(x, y, grad))
return new Tensor[] { grad, -grad };
- var (sx, sy) = SmartBroadcastGradientArgs(x, y);
- var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+ var broads = SmartBroadcastGradientArgs(x, y);
+ var (sx, rx, must_reduce_x) = broads[0];
+ var (sy, ry, must_reduce_y) = broads[1];
var gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
var gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy);
@@ -728,8 +718,10 @@ namespace Tensorflow.Gradients
var z = op.outputs[0];
- var (sx, sy) = SmartBroadcastGradientArgs(x, y);
- var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+ var broads = SmartBroadcastGradientArgs(x, y);
+ var (sx, rx, must_reduce_x) = broads[0];
+ var (sy, ry, must_reduce_y) = broads[1];
+
x = math_ops.conj(x);
y = math_ops.conj(y);
z = math_ops.conj(z);
@@ -761,7 +753,7 @@ namespace Tensorflow.Gradients
///
///
///
- private static (Tensor, Tensor) SmartBroadcastGradientArgs(Tensor x, Tensor y)
+ private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y)
{
Tensor sx, sy;
if (x.TensorShape.is_fully_defined() &&
@@ -769,6 +761,13 @@ namespace Tensorflow.Gradients
{
sx = array_ops.shape(x);
sy = array_ops.shape(y);
+
+ var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+ return new[]
+ {
+ (sx, rx, true),
+ (sy, ry, true)
+ };
}
else
{
@@ -776,7 +775,8 @@
sy = array_ops.shape_internal(y, optimize: false);

- return (sx, sy);
+ var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+ return new[] { (sx, rx, true), (sy, ry, true) };
}
}
}
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index 9d459e15..6b7819f9 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -12,7 +12,7 @@
Apache 2.0, Haiping Chen 2020
TensorFlow.Keras
https://github.com/SciSharp/TensorFlow.NET
- https://avatars3.githubusercontent.com/u/44989469?s=200&v=4
+ https://avatars3.githubusercontent.com/u/44989469?s=200&v=4
https://github.com/SciSharp/TensorFlow.NET
Keras for .NET is a C# version of Keras ported from the python version.
Keras for .NET
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs
index 0cf0d2f5..87140b00 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs
@@ -44,8 +44,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
using var gt = tf.GradientTape();
var y = x * w;
var gr = gt.gradient(y, w);
- Assert.AreNotEqual(null, gr);
+ Assert.AreEqual(gr.numpy(), new float[] { 0, 0 });
}
-
}
}