Browse Source

Merge pull request #1010 from AsakusaRinne/fix_mean_square_error_grad

Add shape deduction for the mean squared error gradient.
tags/v0.100.5-BERT-load
Haiping GitHub 2 years ago
parent
commit
86eb48bfd3
No known key found for this signature in database. GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 6 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Gradients/math_grad.cs
  2. +28
    -5
      src/TensorFlowNET.Core/Gradients/nn_grad.cs

+ 1
- 1
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -840,7 +840,7 @@ namespace Tensorflow.Gradients
/// <param name="x"></param> /// <param name="x"></param>
/// <param name="y"></param> /// <param name="y"></param>
/// <returns></returns> /// <returns></returns>
private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad)
public static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad)
{ {
Tensor sx, sy; Tensor sx, sy;
if (x.shape.IsFullyDefined && if (x.shape.IsFullyDefined &&


+ 28
- 5
src/TensorFlowNET.Core/Gradients/nn_grad.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/ ******************************************************************************/


using System; using System;
using System.Diagnostics;
using System.Linq; using System.Linq;
using Tensorflow.Operations; using Tensorflow.Operations;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -135,13 +136,35 @@ namespace Tensorflow.Gradients
{ {
Tensor x = op.inputs[0]; Tensor x = op.inputs[0];
Tensor y = op.inputs[1]; Tensor y = op.inputs[1];
var grad = grads[0];
var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype); var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype);
var x_grad = math_ops.scalar_mul(scale, grads[0]) * (x - y);
return new Tensor[]
var x_grad = math_ops.scalar_mul(scale, grad) * (x - y);
if (math_grad._ShapesFullySpecifiedAndEqual(x, y, grad))
{ {
x_grad,
-x_grad
};
return new Tensor[] { x_grad, -x_grad };
}
var broadcast_info = math_grad.SmartBroadcastGradientArgs(x, y, grad);
Debug.Assert(broadcast_info.Length == 2);
var (sx, rx, must_reduce_x) = broadcast_info[0];
var (sy, ry, must_reduce_y) = broadcast_info[1];
Tensor gx, gy;
if (must_reduce_x)
{
gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx);
}
else
{
gx = x_grad;
}
if (must_reduce_y)
{
gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy);
}
else
{
gy = -x_grad;
}
return new Tensor[] { gx, gy };
} }


/// <summary> /// <summary>


Loading…
Cancel
Save