
Fixed add GradientOperatorMulTest #642

tags/v0.30
Oceania2018 5 years ago
commit 91399b1527
4 changed files with 30 additions and 31 deletions:

  1. src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs (+3, -2)
  2. src/TensorFlowNET.Core/Gradients/math_grad.cs (+25, -26)
  3. src/TensorFlowNET.Keras/Tensorflow.Keras.csproj (+1, -1)
  4. test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs (+1, -2)

src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs (+3, -2)

@@ -1,4 +1,5 @@
+using System;
 using System.Linq;
 using Tensorflow.Gradients;
 using static Tensorflow.Binding;
 using static Tensorflow.tensorflow;
@@ -37,7 +38,7 @@ namespace Tensorflow.Eager
             }*/
             }

-            // Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}");
+            Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}");
             if (!should_record) return should_record;

             Tensor[] op_outputs;
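The net effect here is that the RecordGradient trace line is switched back on, which is also why using System; is added (Console lives in that namespace). For context, a minimal sketch of code that drives this path through the managed API, patterned on the repository's unit tests; the trace output is illustrative, not captured:

    using static Tensorflow.Binding;

    // Sketch: executing an eager op while a GradientTape is active routes
    // through EagerRunner.RecordGradient, so the re-enabled trace line
    // prints once per recorded op.
    var w = tf.Variable(new float[] { 1f, 1f });  // variables are watched automatically
    using var tape = tf.GradientTape();
    var y = tf.constant(2f) * w;   // the eager Mul is recorded via RecordGradient
    var g = tape.gradient(y, w);   // replays the tape to compute dy/dw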


src/TensorFlowNET.Core/Gradients/math_grad.cs (+25, -26)

@@ -212,8 +212,9 @@ namespace Tensorflow.Gradients
                 };
             }

-            var (sx, sy) = SmartBroadcastGradientArgs(x, y);
-            var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+            var broads = SmartBroadcastGradientArgs(x, y);
+            var (sx, rx, must_reduce_x) = broads[0];
+            var (sy, ry, must_reduce_y) = broads[1];

             x = math_ops.conj(x);
             y = math_ops.conj(y);
@@ -222,33 +223,21 @@

             if (op is EagerOperation op_eager1 &&
                 op_eager1.SkipInputIndices.Contains(0))
-            {
-                return new Tensor[]
-                {
-                    gen_math_ops.mul(grad, math_ops.conj(y)),
-                    null
-                };
-            }
+                // else if not must_reduce_x:
+                //     gx = gen_math_ops.mul(grad, y)
+                gy = null;
+            else if (!must_reduce_x)
+                gx = gen_math_ops.mul(grad, y);
             else
-            {
                 gx = array_ops.reshape(
                     math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx);
-            }

             if (op is EagerOperation op_eager2 &&
                 op_eager2.SkipInputIndices.Contains(1))
-            {
-
-            }
+                // else if not must_reduce_y:
+                //     gy = gen_math_ops.mul(x, grad)
+                gy = null;
+            else if (!must_reduce_y)
+                gy = gen_math_ops.mul(x, grad);
             else
-            {
                 gy = array_ops.reshape(
                     math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy);
-            }

             return new Tensor[] { gx, gy };
         }
@@ -479,8 +468,9 @@ namespace Tensorflow.Gradients
                 _ShapesFullySpecifiedAndEqual(x, y, grad))
                 return new Tensor[] { grad, -grad };

-            var (sx, sy) = SmartBroadcastGradientArgs(x, y);
-            var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+            var broads = SmartBroadcastGradientArgs(x, y);
+            var (sx, rx, must_reduce_x) = broads[0];
+            var (sy, ry, must_reduce_y) = broads[1];

             var gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
             var gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy);
@@ -728,8 +718,10 @@ namespace Tensorflow.Gradients

             var z = op.outputs[0];

-            var (sx, sy) = SmartBroadcastGradientArgs(x, y);
-            var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+            var broads = SmartBroadcastGradientArgs(x, y);
+            var (sx, rx, must_reduce_x) = broads[0];
+            var (sy, ry, must_reduce_y) = broads[1];

             x = math_ops.conj(x);
             y = math_ops.conj(y);
             z = math_ops.conj(z);
@@ -761,7 +753,7 @@ namespace Tensorflow.Gradients
         /// <param name="x"></param>
         /// <param name="y"></param>
         /// <returns></returns>
-        private static (Tensor, Tensor) SmartBroadcastGradientArgs(Tensor x, Tensor y)
+        private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y)
         {
             Tensor sx, sy;
             if (x.TensorShape.is_fully_defined() &&
@@ -769,6 +761,13 @@
             {
                 sx = array_ops.shape(x);
                 sy = array_ops.shape(y);
+
+                var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+                return new[]
+                {
+                    (sx, rx, true),
+                    (sy, ry, true)
+                };
             }
             else
             {
@@ -776,7 +775,7 @@
                 sy = array_ops.shape_internal(y, optimize: false);
             }

-            return (sx, sy);
+            throw new NotImplementedException("");
         }
     }
 }
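Taken together: SmartBroadcastGradientArgs now hands callers one (shape, reduction_indices, must_reduce) triple per input, so a gradient like _MulGrad can multiply directly when must_reduce is false and only pay for the reduce_sum plus reshape when broadcasting actually happened. Note that the fully-defined-shape branch still reports must_reduce as true for both inputs, and the dynamic-shape branch now throws NotImplementedException instead of returning shapes without reduction indices. Below is a minimal sketch of the broadcasting rule these gradients encode; the shapes and values are illustrative, and tape.watch is assumed to mirror the Python API:

    using static Tensorflow.Binding;

    // z = x * y with x: [2,3] and y: [3] broadcasts y across the rows of x.
    // dz/dy must therefore be summed over axis 0 (the ry indices computed by
    // broadcast_gradient_args) and reshaped to y's shape [3]; that is the
    // reduce_sum/reshape branch _MulGrad takes when must_reduce is true.
    var x = tf.constant(new float[,] { { 1f, 2f, 3f }, { 4f, 5f, 6f } });
    var y = tf.constant(new float[] { 10f, 20f, 30f });
    using var tape = tf.GradientTape();
    tape.watch(y);                 // constants need explicit watching (assumed API)
    var z = x * y;
    var gy = tape.gradient(z, y);  // expected: column sums of x, i.e. [5, 7, 9]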

src/TensorFlowNET.Keras/Tensorflow.Keras.csproj (+1, -1)

@@ -12,7 +12,7 @@
     <Copyright>Apache 2.0, Haiping Chen 2020</Copyright>
     <PackageId>TensorFlow.Keras</PackageId>
     <PackageProjectUrl>https://github.com/SciSharp/TensorFlow.NET</PackageProjectUrl>
-    <PackageIcon>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIcon>
+    <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
     <RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl>
     <PackageReleaseNotes>Keras for .NET is a C# version of Keras ported from the python version.</PackageReleaseNotes>
     <Description>Keras for .NET
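Likely context for this swap: NuGet's newer <PackageIcon> property expects a package-relative path to an icon file packed inside the .nupkg, so pointing it at a remote URL breaks packing; the older <PackageIconUrl> property is the one that accepts a URL, which is presumably why the avatar link moves back there.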


test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs (+1, -2)

@@ -44,8 +44,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
             using var gt = tf.GradientTape();
             var y = x * w;
             var gr = gt.gradient(y, w);
-            Assert.AreNotEqual(null, gr);
+            Assert.AreEqual(new float[] { 0, 0 }, gr.numpy());
         }

     }
 }
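The strengthened assertion pins the gradient's actual value instead of only checking that it is non-null. For orientation, the surrounding test might read as follows; only the last context lines and the new assertion are visible in this diff, so the method name (taken from the commit title) and the setup of x and w are assumptions:

    // Hypothetical reconstruction; x and w are guesses chosen so that
    // d(x*w)/dw = x = 0 yields the asserted { 0, 0 } gradient.
    [TestMethod]
    public void GradientOperatorMulTest()
    {
        var x = tf.constant(0f);                     // assumed scalar factor
        var w = tf.Variable(new float[] { 1f, 1f }); // assumed 2-element variable
        using var gt = tf.GradientTape();
        var y = x * w;  // broadcast Mul, exercising the _MulGrad path fixed above
        var gr = gt.gradient(y, w);
        Assert.AreEqual(new float[] { 0, 0 }, gr.numpy());
    }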
