diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs index 4891fcbb..b479ba0b 100644 --- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs +++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs @@ -26,31 +26,8 @@ namespace Tensorflow { static Dictionary> gradientFunctions = null; - /// - /// Regiter new gradient function - /// - /// operation type - /// function delegate - public static void RegisterGradientFunction(string name, Func func) + private static void RegisterFromAssembly() { - if(gradientFunctions == null) - gradientFunctions = new Dictionary>(); - - gradientFunctions[name] = func; - } - - public static void RegisterNoGradientFunction(string name) - { - if (gradientFunctions == null) - gradientFunctions = new Dictionary>(); - - gradientFunctions[name] = null; - } - - public static Func get_gradient_function(Operation op) - { - if (op.inputs == null) return null; - if (gradientFunctions == null) { gradientFunctions = new Dictionary>(); @@ -62,7 +39,8 @@ namespace Tensorflow foreach (var g in gradGroups) { - var methods = g.GetMethods().Where(x => x.GetCustomAttribute() != null) + var methods = g.GetMethods() + .Where(x => x.GetCustomAttribute() != null) .ToArray(); foreach (var m in methods) @@ -78,13 +56,40 @@ namespace Tensorflow } // REGISTER_NO_GRADIENT_OP - methods = g.GetMethods().Where(x => x.GetCustomAttribute() != null) + methods = g.GetMethods() + .Where(x => x.GetCustomAttribute() != null) .ToArray(); foreach (var m in methods) RegisterNoGradientFunction(m.GetCustomAttribute().Name); } } + } + + /// + /// Regiter new gradient function + /// + /// operation type + /// function delegate + public static void RegisterGradientFunction(string name, Func func) + { + RegisterFromAssembly(); + + gradientFunctions[name] = func; + } + + public static void RegisterNoGradientFunction(string name) + { + RegisterFromAssembly(); + + gradientFunctions[name] = null; + } + + public static Func get_gradient_function(Operation op) + { + if (op.inputs == null) return null; + + RegisterFromAssembly(); if (!gradientFunctions.ContainsKey(op.type)) throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj index 3cb7ccb2..a2958d16 100644 --- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj +++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj @@ -65,7 +65,7 @@ https://tensorflownet.readthedocs.io - +