
Extend gradient function capability.

tags/v0.9
Oceania2018 · 6 years ago
commit f13e35d760
8 changed files with 107 additions and 64 deletions
  1. docs/source/Gradient.md (+13 -0)
  2. src/TensorFlowNET.Core/Gradients/RegisterGradient.cs (+16 -0)
  3. src/TensorFlowNET.Core/Gradients/array_grad.cs (+5 -1)
  4. src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs (+3 -2)
  5. src/TensorFlowNET.Core/Gradients/math_grad.cs (+16 -0)
  6. src/TensorFlowNET.Core/Gradients/nn_grad.cs (+7 -0)
  7. src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs (+45 -60)
  8. src/TensorFlowNET.Core/TensorFlowNET.Core.csproj (+2 -1)

docs/source/Gradient.md (+13 -0)

@@ -1,2 +1,15 @@
# Chapter. Gradient

### Register custom gradient function

TF.NET is extensible: custom gradient functions can be registered at runtime to extend or override the built-in gradients.

```csharp
// define and register a gradient function for the "ConcatV2" operation type
ops.RegisterGradientFunction("ConcatV2", (oper, out_grads) =>
{
    var grad = out_grads[0];
    // return one gradient tensor per input of the operation
    return new Tensor[] { };
});
```
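For a complete (if minimal) example, the sketch below re-registers the gradient of the `Neg` operation; the body mirrors `math_grad._NegGrad` introduced in this commit, and the unary `-` tensor operator is likewise taken from the diff below.

```csharp
// A minimal sketch mirroring math_grad._NegGrad: d(-x)/dx = -1,
// so the upstream gradient is simply negated and returned.
ops.RegisterGradientFunction("Neg", (oper, out_grads) =>
{
    return new Tensor[] { -out_grads[0] };
});
```

Registration is keyed by operation type (`gradientFunctions[name] = func`), so the most recent registration for a given key wins.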


src/TensorFlowNET.Core/Gradients/RegisterGradient.cs (+16 -0)

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Gradients
{
    public class RegisterGradient : Attribute
    {
        public string Name { get; set; }

        public RegisterGradient(string name)
        {
            Name = name;
        }
    }
}

src/TensorFlowNET.Core/Gradients/array_grad.cs (+5 -1)

@@ -10,8 +10,10 @@ namespace Tensorflow.Gradients
    /// <summary>
    /// tensorflow\python\ops\array_grad.py
    /// </summary>
    [RegisterGradient("array_grad")]
    public class array_grad
    {
        [RegisterGradient("ConcatV2")]
        public static Tensor[] _ConcatGradV2(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -123,12 +125,13 @@ namespace Tensorflow.Gradients
            return gen_ops.shape_n(inputs);
        }

        [RegisterGradient("Reshape")]
        public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
        {
            return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
        }

        [RegisterGradient("Squeeze")]
        public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
        {
            return new Tensor[] { _ReshapeToInput(op, grads[0]) };
@@ -139,6 +142,7 @@ namespace Tensorflow.Gradients
            return array_ops.reshape(grad, array_ops.shape(op.inputs[0]));
        }

        [RegisterGradient("Transpose")]
        public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads)
        {
            var p = op.inputs[1];


src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs (+3 -2)

@@ -69,11 +69,12 @@ namespace Tensorflow.Gradients
                // false_grad = switch(grad[0], op.inputs[1])[0]
                // true_grad = switch(grad[1], op.inputs[1])[1]
                // return merge([false_grad, true_grad])[0], None
            }
        }

        /// <summary>
        /// Gradients for a Merge op are calculated using a Switch op.
        /// </summary>
        [RegisterGradient("Merge")]
        public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];


src/TensorFlowNET.Core/Gradients/math_grad.cs (+16 -0)

@@ -10,8 +10,10 @@ namespace Tensorflow.Gradients
    /// <summary>
    /// Gradients for operators defined in math_ops.py.
    /// </summary>
    [RegisterGradient("math_grad")]
    public class math_grad
    {
        [RegisterGradient("Add")]
        public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
        {
            var x = op.inputs[0];
@@ -32,6 +34,7 @@ namespace Tensorflow.Gradients
            return new Tensor[] { r1, r2 };
        }

        [RegisterGradient("DivNoNan")]
        public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -59,6 +62,7 @@ namespace Tensorflow.Gradients
        /// <param name="op"></param>
        /// <param name="grads"></param>
        /// <returns></returns>
        [RegisterGradient("Exp")]
        public static Tensor[] _ExpGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -69,11 +73,13 @@ namespace Tensorflow.Gradients
            });
        }

        [RegisterGradient("Identity")]
        public static Tensor[] _IdGrad(Operation op, Tensor[] grads)
        {
            return new Tensor[] { grads[0] };
        }

        [RegisterGradient("Log")]
        public static Tensor[] _LogGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -84,6 +90,7 @@ namespace Tensorflow.Gradients
            });
        }

        [RegisterGradient("Mul")]
        public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
        {
            var x = op.inputs[0];
@@ -112,6 +119,7 @@ namespace Tensorflow.Gradients
            return new Tensor[] { reshape1, reshape2 };
        }

        [RegisterGradient("MatMul")]
        public static Tensor[] _MatMulGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -145,6 +153,7 @@ namespace Tensorflow.Gradients
            return new Tensor[] { grad_a, grad_b };
        }

        [RegisterGradient("Mean")]
        public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -159,6 +168,7 @@ namespace Tensorflow.Gradients
            return new Tensor[] { math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), null };
        }

        [RegisterGradient("Neg")]
        public static Tensor[] _NegGrad(Operation op, Tensor[] grads)
        {
            return new Tensor[] { -grads[0] };
@@ -169,6 +179,7 @@ namespace Tensorflow.Gradients
            return math_ops.floordiv(x, gen_math_ops.maximum(y, 1));
        }

        [RegisterGradient("Sub")]
        public static Tensor[] _SubGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -198,6 +209,7 @@ namespace Tensorflow.Gradients
                !x_shape.Contains(-1);
        }

        [RegisterGradient("Sum")]
        public static Tensor[] _SumGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -231,6 +243,7 @@ namespace Tensorflow.Gradients
            return new Tensor[] { gen_array_ops.tile(grad, tile_scaling), null };
        }

        [RegisterGradient("RealDiv")]
        public static Tensor[] _RealDivGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -254,6 +267,7 @@ namespace Tensorflow.Gradients
            return new Tensor[] { reshape2, reshape1 };
        }

        [RegisterGradient("Sigmoid")]
        public static Tensor[] _SigmoidGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -266,6 +280,7 @@ namespace Tensorflow.Gradients
            });
        }

        [RegisterGradient("Square")]
        public static Tensor[] _SquareGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -279,6 +294,7 @@ namespace Tensorflow.Gradients
            });
        }

        [RegisterGradient("Pow")]
        public static Tensor[] _PowGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];


src/TensorFlowNET.Core/Gradients/nn_grad.cs (+7 -0)

@@ -9,6 +9,7 @@ namespace Tensorflow.Gradients
    /// <summary>
    ///
    /// </summary>
    [RegisterGradient("nn_grad")]
    public class nn_grad
    {
        /// <summary>
@@ -17,6 +18,7 @@ namespace Tensorflow.Gradients
/// <param name="op"></param> /// <param name="op"></param>
/// <param name="grad"></param> /// <param name="grad"></param>
/// <returns></returns> /// <returns></returns>
[RegisterGradient("BiasAdd")]
public static Tensor[] _BiasAddGrad(Operation op, Tensor[] grads) public static Tensor[] _BiasAddGrad(Operation op, Tensor[] grads)
{ {
var grad = grads[0]; var grad = grads[0];
@@ -25,6 +27,7 @@ namespace Tensorflow.Gradients
return new Tensor[] { grad, bias_add_grad }; return new Tensor[] { grad, bias_add_grad };
} }


[RegisterGradient("Relu")]
public static Tensor[] _ReluGrad(Operation op, Tensor[] grads) public static Tensor[] _ReluGrad(Operation op, Tensor[] grads)
{ {
return new Tensor[] { gen_nn_ops.relu_grad(grads[0], op.outputs[0]) }; return new Tensor[] { gen_nn_ops.relu_grad(grads[0], op.outputs[0]) };
@@ -36,6 +39,7 @@ namespace Tensorflow.Gradients
/// <param name="op"></param> /// <param name="op"></param>
/// <param name="grads"></param> /// <param name="grads"></param>
/// <returns></returns> /// <returns></returns>
[RegisterGradient("Softmax")]
public static Tensor[] _SoftmaxGrad(Operation op, Tensor[] grads) public static Tensor[] _SoftmaxGrad(Operation op, Tensor[] grads)
{ {
var grad_softmax = grads[0]; var grad_softmax = grads[0];
@@ -54,6 +58,7 @@ namespace Tensorflow.Gradients
/// <param name="grad_loss"></param> /// <param name="grad_loss"></param>
/// <param name="grad_grad"></param> /// <param name="grad_grad"></param>
/// <returns></returns> /// <returns></returns>
[RegisterGradient("SoftmaxCrossEntropyWithLogits")]
public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads) public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
{ {
var grad_loss = grads[0]; var grad_loss = grads[0];
@@ -74,6 +79,7 @@ namespace Tensorflow.Gradients
}; };
} }


[RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")]
public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads) public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
{ {
var sparse_softmax_grad_without_gradient = array_ops.prevent_gradient( var sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
@@ -111,6 +117,7 @@ namespace Tensorflow.Gradients
/// <param name="op"></param> /// <param name="op"></param>
/// <param name="grads"></param> /// <param name="grads"></param>
/// <returns></returns> /// <returns></returns>
[RegisterGradient("TopK")]
public static Tensor[] _TopKGrad(Operation op, Tensor[] grads) public static Tensor[] _TopKGrad(Operation op, Tensor[] grads)
{ {
var grad = grads[0]; var grad = grads[0];


src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs (+45 -60)

@@ -1,5 +1,7 @@
using System;
using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
using System.Text;
using Tensorflow.Gradients;

@@ -7,74 +9,57 @@ namespace Tensorflow
{
    public partial class ops
    {
+        static Dictionary<string, Func<Operation, Tensor[], Tensor[]>> gradientFunctions = null;
+
+        /// <summary>
+        /// Register a new gradient function
+        /// </summary>
+        /// <param name="name">operation type</param>
+        /// <param name="func">function delegate</param>
+        public static void RegisterGradientFunction(string name, Func<Operation, Tensor[], Tensor[]> func)
+        {
+            if (gradientFunctions == null)
+                gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();
+
+            gradientFunctions[name] = func;
+        }
+
        public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op)
        {
            if (op.inputs == null) return null;

-            // map tensorflow\python\ops\math_grad.py
-            return (oper, out_grads) =>
+            if (gradientFunctions == null)
            {
-                // Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'");
+                gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();

-                switch (oper.type)
+                var gradGroups = Assembly.GetExecutingAssembly()
+                    .GetTypes()
+                    .Where(x => x.GetCustomAttribute<RegisterGradient>() != null)
+                    .ToArray();
+
+                foreach (var g in gradGroups)
                {
-                    case "Add":
-                        return math_grad._AddGrad(oper, out_grads);
-                    case "BiasAdd":
-                        return nn_grad._BiasAddGrad(oper, out_grads);
-                    case "ConcatV2":
-                        return array_grad._ConcatGradV2(oper, out_grads);
-                    case "DivNoNan":
-                        return math_grad._DivNoNanGrad(oper, out_grads);
-                    case "Exp":
-                        return math_grad._ExpGrad(oper, out_grads);
-                    case "Identity":
-                        return math_grad._IdGrad(oper, out_grads);
-                    case "Log":
-                        return math_grad._LogGrad(oper, out_grads);
-                    case "MatMul":
-                        return math_grad._MatMulGrad(oper, out_grads);
-                    case "Merge":
-                        return control_flow_grad._MergeGrad(oper, out_grads);
-                    case "Mul":
-                        return math_grad._MulGrad(oper, out_grads);
-                    case "Mean":
-                        return math_grad._MeanGrad(oper, out_grads);
-                    case "Neg":
-                        return math_grad._NegGrad(oper, out_grads);
-                    case "Sum":
-                        return math_grad._SumGrad(oper, out_grads);
-                    case "Sub":
-                        return math_grad._SubGrad(oper, out_grads);
-                    case "Pow":
-                        return math_grad._PowGrad(oper, out_grads);
-                    case "RealDiv":
-                        return math_grad._RealDivGrad(oper, out_grads);
-                    case "Reshape":
-                        return array_grad._ReshapeGrad(oper, out_grads);
-                    case "Relu":
-                        return nn_grad._ReluGrad(oper, out_grads);
-                    case "Sigmoid":
-                        return math_grad._SigmoidGrad(oper, out_grads);
-                    case "Square":
-                        return math_grad._SquareGrad(oper, out_grads);
-                    case "Squeeze":
-                        return array_grad._SqueezeGrad(oper, out_grads);
-                    case "Softmax":
-                        return nn_grad._SoftmaxGrad(oper, out_grads);
-                    case "SoftmaxCrossEntropyWithLogits":
-                        return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
-                    case "SparseSoftmaxCrossEntropyWithLogits":
-                        return nn_grad._SparseSoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
-                    case "Transpose":
-                        return array_grad._TransposeGrad(oper, out_grads);
-                    case "TopK":
-                    case "TopKV2":
-                        return nn_grad._TopKGrad(oper, out_grads);
-                    default:
-                        throw new NotImplementedException($"get_gradient_function {oper.type}");
+                    var methods = g.GetMethods()
+                        .Where(x => x.GetCustomAttribute<RegisterGradient>() != null)
+                        .ToArray();
+
+                    foreach (var m in methods)
+                    {
+                        RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name,
+                            (oper, out_grads) =>
+                                g.InvokeMember(m.Name,
+                                    BindingFlags.InvokeMethod,
+                                    null,
+                                    null,
+                                    args: new object[] { oper, out_grads }) as Tensor[]
+                        );
+                    }
                }
-            };
+            }
+
+            if (!gradientFunctions.ContainsKey(op.type))
+                throw new NotImplementedException($"can't get gradient function through get_gradient_function {op.type}");
+
+            return gradientFunctions[op.type];
        }
    }
}
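A minimal usage sketch of the two entry points above; the call site is hypothetical (the real consumer is TF.NET's gradient-computation code), but both signatures come from this diff:

```csharp
// Resolve the gradient function for op.type; on the first call the table is
// built by scanning the assembly for [RegisterGradient]-tagged methods.
Func<Operation, Tensor[], Tensor[]> grad_fn = ops.get_gradient_function(op);

// One upstream gradient per op output in, one gradient per op input out.
Tensor[] in_grads = grad_fn(op, out_grads);
```

Note that `get_gradient_function` now throws at lookup time when `op.type` has no registered entry, rather than inside the returned delegate as the old switch statement did.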

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj (+2 -1)

@@ -20,7 +20,8 @@ Docs: https://tensorflownet.readthedocs.io</Description>
    <AssemblyVersion>0.8.1.0</AssemblyVersion>
    <PackageReleaseNotes>Changes since v0.8:

-Removed global static graph instance.</PackageReleaseNotes>
+1. Removed global static graph instance.
+2. Provide custom gradient function.</PackageReleaseNotes>
    <LangVersion>7.2</LangVersion>
    <FileVersion>0.8.1.0</FileVersion>
  </PropertyGroup>

