From f13e35d760a380a79ae0b0651c08df6348afa9e0 Mon Sep 17 00:00:00 2001
From: Oceania2018
Date: Thu, 6 Jun 2019 23:19:11 -0500
Subject: [PATCH] extend gradient function capability.

---
 docs/source/Gradient.md                          |  30 +++++
 .../Gradients/RegisterGradient.cs                |  16 +++
 .../Gradients/array_grad.cs                      |   6 +-
 .../Gradients/control_flow_grad.py.cs            |   5 +-
 src/TensorFlowNET.Core/Gradients/math_grad.cs    |  16 +++
 src/TensorFlowNET.Core/Gradients/nn_grad.cs      |   7 ++
 .../ops.gradient_function_mapping.cs             | 115 ++++++++------
 .../TensorFlowNET.Core.csproj                    |   3 +-
 8 files changed, 134 insertions(+), 64 deletions(-)
 create mode 100644 src/TensorFlowNET.Core/Gradients/RegisterGradient.cs

diff --git a/docs/source/Gradient.md b/docs/source/Gradient.md
index 1c63a1c0..818ec73e 100644
--- a/docs/source/Gradient.md
+++ b/docs/source/Gradient.md
@@ -1,2 +1,32 @@
 # Chapter. Gradient
 
+### Register custom gradient function
+
+TF.NET is extensible: a custom gradient function can be registered for any operation type. Registering a function for a type that already has one replaces the existing mapping.
+
+```csharp
+// define a gradient function for the ConcatV2 operation
+ops.RegisterGradientFunction("ConcatV2", (oper, out_grads) =>
+{
+    // out_grads holds the gradients flowing back from downstream ops
+    var grad = out_grads[0];
+    // return one gradient tensor per input of the operation
+    return new Tensor[] { };
+});
+```
+
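+Built-in gradient functions are wired into the same lookup table through the `RegisterGradient` attribute: a class and its static methods are tagged, and `ops.get_gradient_function` discovers them by reflection the first time a gradient is requested. A minimal sketch, assuming a hypothetical operation type `MyOp`:
+
+```csharp
+[RegisterGradient("my_grad")] // marks the class for the reflection scan
+public class my_grad
+{
+    [RegisterGradient("MyOp")] // maps the "MyOp" operation type to this method
+    public static Tensor[] _MyOpGrad(Operation op, Tensor[] grads)
+    {
+        // pass the upstream gradient through unchanged
+        return new Tensor[] { grads[0] };
+    }
+}
+```
+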
/// + [RegisterGradient("Merge")] public static Tensor[] _MergeGrad(Operation op, Tensor[] grads) { var grad = grads[0]; diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 5b5f6d4c..3f4ab94d 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -10,8 +10,10 @@ namespace Tensorflow.Gradients /// /// Gradients for operators defined in math_ops.py. /// + [RegisterGradient("math_grad")] public class math_grad { + [RegisterGradient("Add")] public static Tensor[] _AddGrad(Operation op, Tensor[] grads) { var x = op.inputs[0]; @@ -32,6 +34,7 @@ namespace Tensorflow.Gradients return new Tensor[] { r1, r2 }; } + [RegisterGradient("DivNoNan")] public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads) { var grad = grads[0]; @@ -59,6 +62,7 @@ namespace Tensorflow.Gradients /// /// /// + [RegisterGradient("Exp")] public static Tensor[] _ExpGrad(Operation op, Tensor[] grads) { var grad = grads[0]; @@ -69,11 +73,13 @@ namespace Tensorflow.Gradients }); } + [RegisterGradient("Identity")] public static Tensor[] _IdGrad(Operation op, Tensor[] grads) { return new Tensor[] { grads[0] }; } + [RegisterGradient("Log")] public static Tensor[] _LogGrad(Operation op, Tensor[] grads) { var grad = grads[0]; @@ -84,6 +90,7 @@ namespace Tensorflow.Gradients }); } + [RegisterGradient("Mul")] public static Tensor[] _MulGrad(Operation op, Tensor[] grads) { var x = op.inputs[0]; @@ -112,6 +119,7 @@ namespace Tensorflow.Gradients return new Tensor[] { reshape1, reshape2 }; } + [RegisterGradient("MatMul")] public static Tensor[] _MatMulGrad(Operation op, Tensor[] grads) { var grad = grads[0]; @@ -145,6 +153,7 @@ namespace Tensorflow.Gradients return new Tensor[] { grad_a, grad_b }; } + [RegisterGradient("Mean")] public static Tensor[] _MeanGrad(Operation op, Tensor[] grads) { var grad = grads[0]; @@ -159,6 +168,7 @@ namespace Tensorflow.Gradients return new Tensor[] { math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), null }; } + [RegisterGradient("Neg")] public static Tensor[] _NegGrad(Operation op, Tensor[] grads) { return new Tensor[] { -grads[0] }; @@ -169,6 +179,7 @@ namespace Tensorflow.Gradients return math_ops.floordiv(x, gen_math_ops.maximum(y, 1)); } + [RegisterGradient("Sub")] public static Tensor[] _SubGrad(Operation op, Tensor[] grads) { var grad = grads[0]; @@ -198,6 +209,7 @@ namespace Tensorflow.Gradients !x_shape.Contains(-1); } + [RegisterGradient("Sum")] public static Tensor[] _SumGrad(Operation op, Tensor[] grads) { var grad = grads[0]; @@ -231,6 +243,7 @@ namespace Tensorflow.Gradients return new Tensor[] { gen_array_ops.tile(grad, tile_scaling), null }; } + [RegisterGradient("RealDiv")] public static Tensor[] _RealDivGrad(Operation op, Tensor[] grads) { var grad = grads[0]; @@ -254,6 +267,7 @@ namespace Tensorflow.Gradients return new Tensor[] { reshape2, reshape1 }; } + [RegisterGradient("Sigmoid")] public static Tensor[] _SigmoidGrad(Operation op, Tensor[] grads) { var grad = grads[0]; @@ -266,6 +280,7 @@ namespace Tensorflow.Gradients }); } + [RegisterGradient("Square")] public static Tensor[] _SquareGrad(Operation op, Tensor[] grads) { var grad = grads[0]; @@ -279,6 +294,7 @@ namespace Tensorflow.Gradients }); } + [RegisterGradient("Pow")] public static Tensor[] _PowGrad(Operation op, Tensor[] grads) { var grad = grads[0]; diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs index a28d1bc5..b7d46b2c 100644 
--- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
@@ -9,6 +9,7 @@ namespace Tensorflow.Gradients
     /// <summary>
     /// 
    /// </summary>
+    [RegisterGradient("nn_grad")]
     public class nn_grad
     {
         /// <summary>
@@ -17,6 +18,7 @@ namespace Tensorflow.Gradients
         /// <param name="op"></param>
         /// <param name="grads"></param>
         /// <returns></returns>
+        [RegisterGradient("BiasAdd")]
         public static Tensor[] _BiasAddGrad(Operation op, Tensor[] grads)
         {
             var grad = grads[0];
@@ -25,6 +27,7 @@ namespace Tensorflow.Gradients
             return new Tensor[] { grad, bias_add_grad };
         }
 
+        [RegisterGradient("Relu")]
         public static Tensor[] _ReluGrad(Operation op, Tensor[] grads)
         {
             return new Tensor[] { gen_nn_ops.relu_grad(grads[0], op.outputs[0]) };
@@ -36,6 +39,7 @@ namespace Tensorflow.Gradients
         /// <param name="op"></param>
         /// <param name="grads"></param>
         /// <returns></returns>
+        [RegisterGradient("Softmax")]
         public static Tensor[] _SoftmaxGrad(Operation op, Tensor[] grads)
         {
             var grad_softmax = grads[0];
@@ -54,6 +58,7 @@ namespace Tensorflow.Gradients
         /// <param name="op"></param>
         /// <param name="grads"></param>
         /// <returns></returns>
+        [RegisterGradient("SoftmaxCrossEntropyWithLogits")]
         public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
         {
             var grad_loss = grads[0];
@@ -74,6 +79,7 @@ namespace Tensorflow.Gradients
             };
         }
 
+        [RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")]
         public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
         {
             var sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
@@ -111,6 +117,7 @@ namespace Tensorflow.Gradients
         /// <param name="op"></param>
         /// <param name="grads"></param>
         /// <returns></returns>
+        [RegisterGradient("TopK")]
         public static Tensor[] _TopKGrad(Operation op, Tensor[] grads)
         {
             var grad = grads[0];
diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
index 98574339..477a39ff 100644
--- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
+++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
@@ -1,5 +1,7 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
 using System.Text;
 using Tensorflow.Gradients;
 
@@ -7,74 +9,67 @@ namespace Tensorflow
 {
     public partial class ops
     {
+        static Dictionary<string, Func<Operation, Tensor[], Tensor[]>> gradientFunctions = null;
+
+        /// <summary>
+        /// Register a new gradient function
+        /// </summary>
+        /// <param name="name">operation type</param>
+        /// <param name="func">function delegate</param>
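+        /// <example>
+        /// A minimal sketch; the delegate receives the forward operation and the
+        /// upstream gradients, and returns one gradient tensor per op input:
+        /// <code>
+        /// ops.RegisterGradientFunction("ConcatV2", (oper, out_grads) =>
+        ///     new Tensor[] { /* gradients for each input */ });
+        /// </code>
+        /// </example>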
+        public static void RegisterGradientFunction(string name, Func<Operation, Tensor[], Tensor[]> func)
+        {
+            if (gradientFunctions == null)
+                gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();
+
+            gradientFunctions[name] = func;
+        }
+
         public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op)
         {
             if (op.inputs == null) return null;
 
-            // map tensorflow\python\ops\math_grad.py
-            return (oper, out_grads) =>
+            if (gradientFunctions == null)
             {
-                // Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'");
+                gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();
 
-                switch (oper.type)
+                // locate every gradient class tagged with [RegisterGradient]
+                var gradGroups = Assembly.GetExecutingAssembly()
+                    .GetTypes()
+                    .Where(x => x.GetCustomAttribute<RegisterGradient>() != null)
+                    .ToArray();
+
+                foreach (var g in gradGroups)
                 {
-                    case "Add":
-                        return math_grad._AddGrad(oper, out_grads);
-                    case "BiasAdd":
-                        return nn_grad._BiasAddGrad(oper, out_grads);
-                    case "ConcatV2":
-                        return array_grad._ConcatGradV2(oper, out_grads);
-                    case "DivNoNan":
-                        return math_grad._DivNoNanGrad(oper, out_grads);
-                    case "Exp":
-                        return math_grad._ExpGrad(oper, out_grads);
-                    case "Identity":
-                        return math_grad._IdGrad(oper, out_grads);
-                    case "Log":
-                        return math_grad._LogGrad(oper, out_grads);
-                    case "MatMul":
-                        return math_grad._MatMulGrad(oper, out_grads);
-                    case "Merge":
-                        return control_flow_grad._MergeGrad(oper, out_grads);
-                    case "Mul":
-                        return math_grad._MulGrad(oper, out_grads);
-                    case "Mean":
-                        return math_grad._MeanGrad(oper, out_grads);
-                    case "Neg":
-                        return math_grad._NegGrad(oper, out_grads);
-                    case "Sum":
-                        return math_grad._SumGrad(oper, out_grads);
-                    case "Sub":
-                        return math_grad._SubGrad(oper, out_grads);
-                    case "Pow":
-                        return math_grad._PowGrad(oper, out_grads);
-                    case "RealDiv":
-                        return math_grad._RealDivGrad(oper, out_grads);
-                    case "Reshape":
-                        return array_grad._ReshapeGrad(oper, out_grads);
-                    case "Relu":
-                        return nn_grad._ReluGrad(oper, out_grads);
-                    case "Sigmoid":
-                        return math_grad._SigmoidGrad(oper, out_grads);
-                    case "Square":
-                        return math_grad._SquareGrad(oper, out_grads);
-                    case "Squeeze":
-                        return array_grad._SqueezeGrad(oper, out_grads);
-                    case "Softmax":
-                        return nn_grad._SoftmaxGrad(oper, out_grads);
-                    case "SoftmaxCrossEntropyWithLogits":
-                        return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
-                    case "SparseSoftmaxCrossEntropyWithLogits":
-                        return nn_grad._SparseSoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
-                    case "Transpose":
-                        return array_grad._TransposeGrad(oper, out_grads);
-                    case "TopK":
-                    case "TopKV2":
-                        return nn_grad._TopKGrad(oper, out_grads);
-                    default:
-                        throw new NotImplementedException($"get_gradient_function {oper.type}");
+                    // each tagged static method handles one operation type
+                    var methods = g.GetMethods().Where(x => x.GetCustomAttribute<RegisterGradient>() != null)
+                        .ToArray();
+
+                    foreach (var m in methods)
+                    {
+                        RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name,
+                            (oper, out_grads) =>
+                                g.InvokeMember(m.Name,
+                                    BindingFlags.InvokeMethod,
+                                    null,
+                                    null,
+                                    args: new object[] { oper, out_grads }) as Tensor[]
+                        );
+                    }
                 }
-            };
+            }
+
+            if (!gradientFunctions.ContainsKey(op.type))
+                throw new NotImplementedException($"can't get gradient function through get_gradient_function {op.type}");
+
+            return gradientFunctions[op.type];
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index 86c53286..eef35f50 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -20,7 +20,8 @@ Docs: https://tensorflownet.readthedocs.io</Description>
     <AssemblyVersion>0.8.1.0</AssemblyVersion>
     <PackageReleaseNotes>Changes since v0.8:
-Removed global static graph instance.</PackageReleaseNotes>
+1. Removed global static graph instance.
+2. Added support for custom gradient functions.</PackageReleaseNotes>
     <LangVersion>7.2</LangVersion>
     <FileVersion>0.8.1.0</FileVersion>
   </PropertyGroup>
 