Browse Source

Calling RegisterGradientFunction too early #464

tags/v0.20
Oceania2018 6 years ago
parent
commit
991b8702c3
2 changed files with 32 additions and 27 deletions
  1. +31
    -26
      src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
  2. +1
    -1
      src/TensorFlowNET.Core/TensorFlow.Binding.csproj

+ 31
- 26
src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs View File

@@ -26,31 +26,8 @@ namespace Tensorflow
{
static Dictionary<string, Func<Operation, Tensor[], Tensor[]>> gradientFunctions = null;

/// <summary>
/// Regiter new gradient function
/// </summary>
/// <param name="name">operation type</param>
/// <param name="func">function delegate</param>
public static void RegisterGradientFunction(string name, Func<Operation, Tensor[], Tensor[]> func)
private static void RegisterFromAssembly()
{
if(gradientFunctions == null)
gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();

gradientFunctions[name] = func;
}

public static void RegisterNoGradientFunction(string name)
{
if (gradientFunctions == null)
gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();

gradientFunctions[name] = null;
}

public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op)
{
if (op.inputs == null) return null;

if (gradientFunctions == null)
{
gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();
@@ -62,7 +39,8 @@ namespace Tensorflow

foreach (var g in gradGroups)
{
var methods = g.GetMethods().Where(x => x.GetCustomAttribute<RegisterGradient>() != null)
var methods = g.GetMethods()
.Where(x => x.GetCustomAttribute<RegisterGradient>() != null)
.ToArray();

foreach (var m in methods)
@@ -78,13 +56,40 @@ namespace Tensorflow
}

// REGISTER_NO_GRADIENT_OP
methods = g.GetMethods().Where(x => x.GetCustomAttribute<RegisterNoGradient>() != null)
methods = g.GetMethods()
.Where(x => x.GetCustomAttribute<RegisterNoGradient>() != null)
.ToArray();

foreach (var m in methods)
RegisterNoGradientFunction(m.GetCustomAttribute<RegisterNoGradient>().Name);
}
}
}

/// <summary>
/// Regiter new gradient function
/// </summary>
/// <param name="name">operation type</param>
/// <param name="func">function delegate</param>
public static void RegisterGradientFunction(string name, Func<Operation, Tensor[], Tensor[]> func)
{
RegisterFromAssembly();

gradientFunctions[name] = func;
}

public static void RegisterNoGradientFunction(string name)
{
RegisterFromAssembly();

gradientFunctions[name] = null;
}

public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op)
{
if (op.inputs == null) return null;

RegisterFromAssembly();

if (!gradientFunctions.ContainsKey(op.type))
throw new LookupError($"can't get graident function through get_gradient_function {op.type}");


+ 1
- 1
src/TensorFlowNET.Core/TensorFlow.Binding.csproj View File

@@ -65,7 +65,7 @@ https://tensorflownet.readthedocs.io</Description>
</ItemGroup>

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.10.1" />
<PackageReference Include="Google.Protobuf" Version="3.11.2" />
<PackageReference Include="NumSharp" Version="0.20.4" />
</ItemGroup>



Loading…
Cancel
Save