|
|
@@ -26,31 +26,8 @@ namespace Tensorflow |
|
|
{ |
|
|
{ |
|
|
static Dictionary<string, Func<Operation, Tensor[], Tensor[]>> gradientFunctions = null; |
|
|
static Dictionary<string, Func<Operation, Tensor[], Tensor[]>> gradientFunctions = null; |
|
|
|
|
|
|
|
|
/// <summary> |
|
|
|
|
|
/// Regiter new gradient function |
|
|
|
|
|
/// </summary> |
|
|
|
|
|
/// <param name="name">operation type</param> |
|
|
|
|
|
/// <param name="func">function delegate</param> |
|
|
|
|
|
public static void RegisterGradientFunction(string name, Func<Operation, Tensor[], Tensor[]> func) |
|
|
|
|
|
|
|
|
private static void RegisterFromAssembly() |
|
|
{ |
|
|
{ |
|
|
if(gradientFunctions == null) |
|
|
|
|
|
gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>(); |
|
|
|
|
|
|
|
|
|
|
|
gradientFunctions[name] = func; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public static void RegisterNoGradientFunction(string name) |
|
|
|
|
|
{ |
|
|
|
|
|
if (gradientFunctions == null) |
|
|
|
|
|
gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>(); |
|
|
|
|
|
|
|
|
|
|
|
gradientFunctions[name] = null; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op) |
|
|
|
|
|
{ |
|
|
|
|
|
if (op.inputs == null) return null; |
|
|
|
|
|
|
|
|
|
|
|
if (gradientFunctions == null) |
|
|
if (gradientFunctions == null) |
|
|
{ |
|
|
{ |
|
|
gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>(); |
|
|
gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>(); |
|
|
@@ -62,7 +39,8 @@ namespace Tensorflow |
|
|
|
|
|
|
|
|
foreach (var g in gradGroups) |
|
|
foreach (var g in gradGroups) |
|
|
{ |
|
|
{ |
|
|
var methods = g.GetMethods().Where(x => x.GetCustomAttribute<RegisterGradient>() != null) |
|
|
|
|
|
|
|
|
var methods = g.GetMethods() |
|
|
|
|
|
.Where(x => x.GetCustomAttribute<RegisterGradient>() != null) |
|
|
.ToArray(); |
|
|
.ToArray(); |
|
|
|
|
|
|
|
|
foreach (var m in methods) |
|
|
foreach (var m in methods) |
|
|
@@ -78,13 +56,40 @@ namespace Tensorflow |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// REGISTER_NO_GRADIENT_OP |
|
|
// REGISTER_NO_GRADIENT_OP |
|
|
methods = g.GetMethods().Where(x => x.GetCustomAttribute<RegisterNoGradient>() != null) |
|
|
|
|
|
|
|
|
methods = g.GetMethods() |
|
|
|
|
|
.Where(x => x.GetCustomAttribute<RegisterNoGradient>() != null) |
|
|
.ToArray(); |
|
|
.ToArray(); |
|
|
|
|
|
|
|
|
foreach (var m in methods) |
|
|
foreach (var m in methods) |
|
|
RegisterNoGradientFunction(m.GetCustomAttribute<RegisterNoGradient>().Name); |
|
|
RegisterNoGradientFunction(m.GetCustomAttribute<RegisterNoGradient>().Name); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/// <summary> |
|
|
|
|
|
/// Regiter new gradient function |
|
|
|
|
|
/// </summary> |
|
|
|
|
|
/// <param name="name">operation type</param> |
|
|
|
|
|
/// <param name="func">function delegate</param> |
|
|
|
|
|
public static void RegisterGradientFunction(string name, Func<Operation, Tensor[], Tensor[]> func) |
|
|
|
|
|
{ |
|
|
|
|
|
RegisterFromAssembly(); |
|
|
|
|
|
|
|
|
|
|
|
gradientFunctions[name] = func; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public static void RegisterNoGradientFunction(string name) |
|
|
|
|
|
{ |
|
|
|
|
|
RegisterFromAssembly(); |
|
|
|
|
|
|
|
|
|
|
|
gradientFunctions[name] = null; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op) |
|
|
|
|
|
{ |
|
|
|
|
|
if (op.inputs == null) return null; |
|
|
|
|
|
|
|
|
|
|
|
RegisterFromAssembly(); |
|
|
|
|
|
|
|
|
if (!gradientFunctions.ContainsKey(op.type)) |
|
|
if (!gradientFunctions.ContainsKey(op.type)) |
|
|
throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); |
|
|
throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); |
|
|
|