Browse Source

Fix gradient of squared_difference #787

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
f55650d9e1
4 changed files with 8 additions and 5 deletions
  1. +4
    -3
      src/TensorFlowNET.Core/Gradients/nn_grad.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Operations/math_ops.cs
  3. +1
    -1
      src/TensorFlowNET.Core/tensorflow.cs
  4. +0
    -1
      test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs

+ 4
- 3
src/TensorFlowNET.Core/Gradients/nn_grad.cs View File

@@ -122,13 +122,14 @@ namespace Tensorflow.Gradients
[RegisterGradient("SquaredDifference")]
public static Tensor[] _SquaredDifferenceGrad(Operation op, Tensor[] grads)
{
//"""Returns the gradient for (x-y)^2."""
Tensor x = op.inputs[0];
Tensor y = op.inputs[1];
var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype);
var x_grad = math_ops.scalar_mul(scale, grads[0]) * (x - y);
return new Tensor[]
{
x,
y
x_grad,
-x_grad
};
}
/// <summary>


+ 3
- 0
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -272,6 +272,9 @@ namespace Tensorflow
public static Tensor mul_no_nan<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.mul_no_nan(x, y, name: name);

public static Tensor scalar_mul<Tscale, Tx>(Tscale scale, Tx x, string name = null)
=> tf.Context.ExecuteOp("Mul", name, new ExecuteOpArgs(scale, x));

public static Tensor real(Tensor input, string name = null)
{
return tf_with(ops.name_scope(name, "Real", new[] { input }), scope =>


+ 1
- 1
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -48,7 +48,7 @@ namespace Tensorflow
public tensorflow()
{
Logger = new LoggerConfiguration()
.MinimumLevel.Error()
.MinimumLevel.Debug()
.WriteTo.Console()
.CreateLogger();



+ 0
- 1
test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs View File

@@ -25,7 +25,6 @@ namespace TensorFlowNET.UnitTest.Gradient
Assert.AreEqual((float)grad, 3.0f);
}

[Ignore]
[TestMethod]
public void SquaredDifference_Constant()
{


Loading…
Cancel
Save