diff --git a/docs/source/Gradient.md b/docs/source/Gradient.md
index 1c63a1c0..818ec73e 100644
--- a/docs/source/Gradient.md
+++ b/docs/source/Gradient.md
@@ -1,2 +1,15 @@
# Chapter. Gradient
+### Register custom gradient function
+
+TF.NET is extensible: a custom gradient function can be registered for any operation type, overriding or supplementing the built-in gradients.
+
+```csharp
+// register a custom gradient function for the "ConcatV2" operation type
+ops.RegisterGradientFunction("ConcatV2", (oper, out_grads) =>
+{
+    // out_grads holds the gradient flowing into each output of oper
+    var grad = out_grads[0];
+    // return one gradient tensor per input of oper (a sketch; real logic goes here)
+    return new Tensor[] { grad };
+});
+```
+
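+Inside the library itself, gradient functions are registered declaratively: a class is marked with the `RegisterGradient` attribute, and each static gradient method is marked with the operation type it handles. These methods are discovered via reflection the first time `ops.get_gradient_function` is called. For example, the built-in `Reshape` gradient is wired up like this:
+
+```csharp
+[RegisterGradient("array_grad")]
+public class array_grad
+{
+    [RegisterGradient("Reshape")]
+    public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
+    {
+        // reshape the incoming gradient back to the shape of the op's original input
+        return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
+    }
+}
+```
+
+Note that the reflection scan only covers the executing assembly (TensorFlowNET.Core), so code outside the library should call `ops.RegisterGradientFunction` directly, as shown above.
+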
diff --git a/src/TensorFlowNET.Core/Gradients/RegisterGradient.cs b/src/TensorFlowNET.Core/Gradients/RegisterGradient.cs
new file mode 100644
index 00000000..f07c613d
--- /dev/null
+++ b/src/TensorFlowNET.Core/Gradients/RegisterGradient.cs
@@ -0,0 +1,16 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Gradients
+{
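+ /// <summary>
+ /// Marks a class or a static method as a gradient-function provider;
+ /// on a method, the name gives the operation type the method handles.
+ /// </summary>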
+ public class RegisterGradient : Attribute
+ {
+ public string Name { get; set; }
+
+ public RegisterGradient(string name)
+ {
+ Name = name;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs
index 4e5b0d89..b7c5494a 100644
--- a/src/TensorFlowNET.Core/Gradients/array_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs
@@ -10,8 +10,10 @@ namespace Tensorflow.Gradients
/// <summary>
/// tensorflow\python\ops\array_grad.py
/// </summary>
+ [RegisterGradient("array_grad")]
public class array_grad
{
+ [RegisterGradient("ConcatV2")]
public static Tensor[] _ConcatGradV2(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -123,12 +125,13 @@ namespace Tensorflow.Gradients
return gen_ops.shape_n(inputs);
}
-
+ [RegisterGradient("Reshape")]
public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
{
return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
}
+ [RegisterGradient("Squeeze")]
public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
{
return new Tensor[] { _ReshapeToInput(op, grads[0]) };
@@ -139,6 +142,7 @@ namespace Tensorflow.Gradients
return array_ops.reshape(grad, array_ops.shape(op.inputs[0]));
}
+ [RegisterGradient("Transpose")]
public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads)
{
var p = op.inputs[1];
diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs
index de61e52b..ec2a16a4 100644
--- a/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs
+++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs
@@ -69,11 +69,12 @@ namespace Tensorflow.Gradients
// false_grad = switch(grad[0], op.inputs[1])[0]
// true_grad = switch(grad[1], op.inputs[1])[1]
// return merge([false_grad, true_grad])[0], None
- }
-
+ }
+
/// <summary>
/// Gradients for a Merge op are calculated using a Switch op.
/// </summary>
+ [RegisterGradient("Merge")]
public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs
index 5b5f6d4c..3f4ab94d 100644
--- a/src/TensorFlowNET.Core/Gradients/math_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs
@@ -10,8 +10,10 @@ namespace Tensorflow.Gradients
/// <summary>
/// Gradients for operators defined in math_ops.py.
/// </summary>
+ [RegisterGradient("math_grad")]
public class math_grad
{
+ [RegisterGradient("Add")]
public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
{
var x = op.inputs[0];
@@ -32,6 +34,7 @@ namespace Tensorflow.Gradients
return new Tensor[] { r1, r2 };
}
+ [RegisterGradient("DivNoNan")]
public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -59,6 +62,7 @@ namespace Tensorflow.Gradients
///
///
///
+ [RegisterGradient("Exp")]
public static Tensor[] _ExpGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -69,11 +73,13 @@ namespace Tensorflow.Gradients
});
}
+ [RegisterGradient("Identity")]
public static Tensor[] _IdGrad(Operation op, Tensor[] grads)
{
return new Tensor[] { grads[0] };
}
+ [RegisterGradient("Log")]
public static Tensor[] _LogGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -84,6 +90,7 @@ namespace Tensorflow.Gradients
});
}
+ [RegisterGradient("Mul")]
public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
{
var x = op.inputs[0];
@@ -112,6 +119,7 @@ namespace Tensorflow.Gradients
return new Tensor[] { reshape1, reshape2 };
}
+ [RegisterGradient("MatMul")]
public static Tensor[] _MatMulGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -145,6 +153,7 @@ namespace Tensorflow.Gradients
return new Tensor[] { grad_a, grad_b };
}
+ [RegisterGradient("Mean")]
public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -159,6 +168,7 @@ namespace Tensorflow.Gradients
return new Tensor[] { math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), null };
}
+ [RegisterGradient("Neg")]
public static Tensor[] _NegGrad(Operation op, Tensor[] grads)
{
return new Tensor[] { -grads[0] };
@@ -169,6 +179,7 @@ namespace Tensorflow.Gradients
return math_ops.floordiv(x, gen_math_ops.maximum(y, 1));
}
+ [RegisterGradient("Sub")]
public static Tensor[] _SubGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -198,6 +209,7 @@ namespace Tensorflow.Gradients
!x_shape.Contains(-1);
}
+ [RegisterGradient("Sum")]
public static Tensor[] _SumGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -231,6 +243,7 @@ namespace Tensorflow.Gradients
return new Tensor[] { gen_array_ops.tile(grad, tile_scaling), null };
}
+ [RegisterGradient("RealDiv")]
public static Tensor[] _RealDivGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -254,6 +267,7 @@ namespace Tensorflow.Gradients
return new Tensor[] { reshape2, reshape1 };
}
+ [RegisterGradient("Sigmoid")]
public static Tensor[] _SigmoidGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -266,6 +280,7 @@ namespace Tensorflow.Gradients
});
}
+ [RegisterGradient("Square")]
public static Tensor[] _SquareGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -279,6 +294,7 @@ namespace Tensorflow.Gradients
});
}
+ [RegisterGradient("Pow")]
public static Tensor[] _PowGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
index a28d1bc5..b7d46b2c 100644
--- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
@@ -9,6 +9,7 @@ namespace Tensorflow.Gradients
///
///
///
+ [RegisterGradient("math_grad")]
public class nn_grad
{
///
@@ -17,6 +18,7 @@ namespace Tensorflow.Gradients
///
///
///
+ [RegisterGradient("BiasAdd")]
public static Tensor[] _BiasAddGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -25,6 +27,7 @@ namespace Tensorflow.Gradients
return new Tensor[] { grad, bias_add_grad };
}
+ [RegisterGradient("Relu")]
public static Tensor[] _ReluGrad(Operation op, Tensor[] grads)
{
return new Tensor[] { gen_nn_ops.relu_grad(grads[0], op.outputs[0]) };
@@ -36,6 +39,7 @@ namespace Tensorflow.Gradients
///
///
///
+ [RegisterGradient("Softmax")]
public static Tensor[] _SoftmaxGrad(Operation op, Tensor[] grads)
{
var grad_softmax = grads[0];
@@ -54,6 +58,7 @@ namespace Tensorflow.Gradients
///
///
///
+ [RegisterGradient("SoftmaxCrossEntropyWithLogits")]
public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
{
var grad_loss = grads[0];
@@ -74,6 +79,7 @@ namespace Tensorflow.Gradients
};
}
+ [RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")]
public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
{
var sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
@@ -111,6 +117,7 @@ namespace Tensorflow.Gradients
///
///
///
+ [RegisterGradient("TopK")]
public static Tensor[] _TopKGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
index 98574339..477a39ff 100644
--- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
+++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
@@ -1,5 +1,7 @@
using System;
using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
using System.Text;
using Tensorflow.Gradients;
@@ -7,74 +9,57 @@ namespace Tensorflow
{
public partial class ops
{
+ static Dictionary<string, Func<Operation, Tensor[], Tensor[]>> gradientFunctions = null;
+
+ /// <summary>
+ /// Register a new gradient function
+ /// </summary>
+ /// <param name="name">operation type</param>
+ /// <param name="func">function delegate</param>
+ public static void RegisterGradientFunction(string name, Func<Operation, Tensor[], Tensor[]> func)
+ {
+ if (gradientFunctions == null)
+ gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();
+
+ gradientFunctions[name] = func;
+ }
+
public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op)
{
if (op.inputs == null) return null;
- // map tensorflow\python\ops\math_grad.py
- return (oper, out_grads) =>
+ if (gradientFunctions == null)
{
- // Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'");
+ gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();
- switch (oper.type)
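+ // first call: scan the executing assembly for classes tagged with [RegisterGradient]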
+ var gradGroups = Assembly.GetExecutingAssembly()
+ .GetTypes()
+ .Where(x => x.GetCustomAttribute<RegisterGradient>() != null)
+ .ToArray();
+
+ foreach (var g in gradGroups)
{
- case "Add":
- return math_grad._AddGrad(oper, out_grads);
- case "BiasAdd":
- return nn_grad._BiasAddGrad(oper, out_grads);
- case "ConcatV2":
- return array_grad._ConcatGradV2(oper, out_grads);
- case "DivNoNan":
- return math_grad._DivNoNanGrad(oper, out_grads);
- case "Exp":
- return math_grad._ExpGrad(oper, out_grads);
- case "Identity":
- return math_grad._IdGrad(oper, out_grads);
- case "Log":
- return math_grad._LogGrad(oper, out_grads);
- case "MatMul":
- return math_grad._MatMulGrad(oper, out_grads);
- case "Merge":
- return control_flow_grad._MergeGrad(oper, out_grads);
- case "Mul":
- return math_grad._MulGrad(oper, out_grads);
- case "Mean":
- return math_grad._MeanGrad(oper, out_grads);
- case "Neg":
- return math_grad._NegGrad(oper, out_grads);
- case "Sum":
- return math_grad._SumGrad(oper, out_grads);
- case "Sub":
- return math_grad._SubGrad(oper, out_grads);
- case "Pow":
- return math_grad._PowGrad(oper, out_grads);
- case "RealDiv":
- return math_grad._RealDivGrad(oper, out_grads);
- case "Reshape":
- return array_grad._ReshapeGrad(oper, out_grads);
- case "Relu":
- return nn_grad._ReluGrad(oper, out_grads);
- case "Sigmoid":
- return math_grad._SigmoidGrad(oper, out_grads);
- case "Square":
- return math_grad._SquareGrad(oper, out_grads);
- case "Squeeze":
- return array_grad._SqueezeGrad(oper, out_grads);
- case "Softmax":
- return nn_grad._SoftmaxGrad(oper, out_grads);
- case "SoftmaxCrossEntropyWithLogits":
- return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
- case "SparseSoftmaxCrossEntropyWithLogits":
- return nn_grad._SparseSoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
- case "Transpose":
- return array_grad._TransposeGrad(oper, out_grads);
- case "TopK":
- case "TopKV2":
- return nn_grad._TopKGrad(oper, out_grads);
- default:
- throw new NotImplementedException($"get_gradient_function {oper.type}");
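+ // collect the static gradient methods tagged with [RegisterGradient]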
+ var methods = g.GetMethods().Where(x => x.GetCustomAttribute<RegisterGradient>() != null)
+ .ToArray();
+
+ foreach (var m in methods)
+ {
+ RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name,
+ (oper, out_grads) =>
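+ // dispatch to the tagged static method via reflection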
+ g.InvokeMember(m.Name,
+ BindingFlags.InvokeMethod,
+ null,
+ null,
+ args: new object[] { oper, out_grads }) as Tensor[]
+ );
+ }
}
- };
+ }
+
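+ // fail clearly when no gradient function is registered for this op type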
+ if (!gradientFunctions.ContainsKey(op.type))
+ throw new NotImplementedException($"can't get gradient function through get_gradient_function {op.type}");
+
+ return gradientFunctions[op.type];
}
}
}
diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index 86c53286..eef35f50 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -20,7 +20,8 @@ Docs: https://tensorflownet.readthedocs.io
0.8.1.0
Changes since v0.8:
-Removed global static graph instance.
+1. Removed global static graph instance.
+2. Added custom gradient function registration.
7.2
0.8.1.0