Browse Source

RegisterNoGradient, LookupError

tags/v0.12
Oceania2018 6 years ago
parent
commit
75ae2e9e09
14 changed files with 279 additions and 27 deletions
  1. +6
    -4
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +17
    -0
      src/TensorFlowNET.Core/Exceptions/LookupError.cs
  3. +33
    -0
      src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs
  4. +36
    -7
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  5. +8
    -7
      src/TensorFlowNET.Core/Gradients/math_grad.cs
  6. +16
    -1
      src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Operations/nn_impl.py.cs
  8. +4
    -3
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  9. +1
    -1
      src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj
  10. +154
    -0
      test/TensorFlowNET.Examples/NeuralNetworks/FullyConnected.cs
  11. +0
    -0
      test/TensorFlowNET.Examples/NeuralNetworks/NeuralNetXor.cs
  12. +1
    -1
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj
  13. +1
    -1
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  14. +1
    -1
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj

+ 6
- 4
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -431,17 +431,19 @@ namespace Tensorflow
/// <param name="input"></param> /// <param name="input"></param>
/// <param name="axis"></param> /// <param name="axis"></param>
/// <returns></returns> /// <returns></returns>
public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null,
    bool keepdims = false, string name = null)
{
    // Resolve the single dimension to reduce over: reduction_indices is
    // honoured only when axis is absent, and axis only when
    // reduction_indices is absent.
    int? dimension = null;
    if (!axis.HasValue && reduction_indices.HasValue)
        dimension = reduction_indices.Value;
    else if (axis.HasValue && !reduction_indices.HasValue)
        dimension = axis.Value;

    // keepdims and name must be forwarded on every path; previously they
    // were dropped whenever a dimension was given, so a call such as
    // reduce_sum(x, axis: 1, keepdims: true) ignored keepdims.
    // The int[] overload is used because it is the one known to accept
    // keepdims/name together with an explicit axis.
    if (dimension.HasValue)
        return math_ops.reduce_sum(input, new[] { dimension.Value }, keepdims: keepdims, name: name);

    // Neither (or both) of axis / reduction_indices supplied: no axis is
    // passed to math_ops.reduce_sum.
    return math_ops.reduce_sum(input, keepdims: keepdims, name: name);
}


public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null)
=> math_ops.reduce_sum(input, axis);
public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null,
bool keepdims = false, string name = null)
=> math_ops.reduce_sum(input, axis, keepdims: keepdims, name: name);


/// <summary> /// <summary>
/// Computes the maximum of elements across dimensions of a tensor. /// Computes the maximum of elements across dimensions of a tensor.


+ 17
- 0
src/TensorFlowNET.Core/Exceptions/LookupError.cs View File

@@ -0,0 +1,17 @@
using System;

namespace Tensorflow
{
/// <summary>
/// Exception type used to report a failed lookup, for example an operation
/// type that has no registered gradient function.
/// </summary>
public class LookupError : TensorflowException
{
    /// <summary>
    /// Creates the exception without a message.
    /// </summary>
    public LookupError()
    {
    }

    /// <summary>
    /// Creates the exception with a description of what could not be found.
    /// </summary>
    /// <param name="message">Text forwarded to the base exception.</param>
    public LookupError(string message)
        : base(message)
    {
    }
}
}

+ 33
- 0
src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs View File

@@ -0,0 +1,33 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;

namespace Tensorflow.Gradients
{
/// <summary>
/// Marks an operation type as having no gradient; the counterpart of
/// REGISTER_NO_GRADIENT_OP("").
/// </summary>
public class RegisterNoGradient : Attribute
{
    /// <summary>
    /// Name given when the attribute was applied, e.g. "ZerosLike".
    /// </summary>
    public string Name { get; set; }

    /// <summary>
    /// Creates the attribute for the given name.
    /// </summary>
    /// <param name="name">Value stored in <see cref="Name"/>.</param>
    public RegisterNoGradient(string name) => this.Name = name;
}
}

+ 36
- 7
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -117,19 +117,44 @@ namespace Tensorflow
Tensor[] in_grads = null; Tensor[] in_grads = null;
var is_partitioned_call = _IsPartitionedCall(op); var is_partitioned_call = _IsPartitionedCall(op);
var is_func_call = false; var is_func_call = false;
var has_out_grads = true;
var has_out_grads = out_grads.Exists(x => x != null);
if (has_out_grads && !stop_ops.Contains(op)) if (has_out_grads && !stop_ops.Contains(op))
{ {
// A grad_fn must be defined, either as a function or as None // A grad_fn must be defined, either as a function or as None
// for ops that do not have gradients. // for ops that do not have gradients.
var grad_fn = ops.get_gradient_function(op);


if (is_func_call)
Func<Operation, Tensor[], Tensor[]> grad_fn = null;
try
{ {
grad_fn = ops.get_gradient_function(op);
}
catch (LookupError)
{
if (is_func_call)
{
if (is_partitioned_call)
{


}
else
{

}
}
else
{
throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
}
} }
else

// if (loop_state)
//loop_state.EnterGradWhileContext(op, before: false);

if ((is_func_call || grad_fn != null) && has_out_grads)
{ {
// NOTE: If _AggregatedGrads didn't compute a value for the i'th
// output, it means that the cost does not depend on output[i],
// therefore dC/doutput[i] is 0.
foreach (var (i, out_grad) in enumerate(out_grads)) foreach (var (i, out_grad) in enumerate(out_grads))
{ {
if (out_grad == null) if (out_grad == null)
@@ -143,13 +168,11 @@ namespace Tensorflow


tf_with(ops.name_scope(op.name + "_grad"), scope1 => tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
{ {
string name1 = scope1;
if (grad_fn != null) if (grad_fn != null)
{ {
in_grads = _MaybeCompile(grad_scope, op, out_grads.Select(x => x[0]).ToArray(), null, grad_fn); in_grads = _MaybeCompile(grad_scope, op, out_grads.Select(x => x[0]).ToArray(), null, grad_fn);
_VerifyGeneratedGradients(in_grads, op);
} }
_VerifyGeneratedGradients(in_grads, op);
if (gate_gradients && in_grads.Count(x => x != null) > 1) if (gate_gradients && in_grads.Count(x => x != null) > 1)
{ {
ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true); ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
@@ -157,6 +180,12 @@ namespace Tensorflow
} }
}); });
} }
else
{
// If no grad_fn is defined or none of out_grads is available,
// just propagate a list of None backwards.
in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
}
} }
else else
{ {


+ 8
- 7
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -96,12 +96,11 @@ namespace Tensorflow.Gradients
}); });
} }


[RegisterGradient("GreaterEqual")]
public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
throw new NotImplementedException("_GreaterEqualGrad");
}
[RegisterNoGradient("GreaterEqual")]
public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null;

[RegisterNoGradient("ZerosLike")]
public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null;


[RegisterGradient("Identity")] [RegisterGradient("Identity")]
public static Tensor[] _IdGrad(Operation op, Tensor[] grads) public static Tensor[] _IdGrad(Operation op, Tensor[] grads)
@@ -415,7 +414,9 @@ namespace Tensorflow.Gradients
var rank = input_0_shape.Length; var rank = input_0_shape.Length;
if (Enumerable.SequenceEqual(Enumerable.Range(0, rank), axes.Data<int>())) if (Enumerable.SequenceEqual(Enumerable.Range(0, rank), axes.Data<int>()))
{ {
grad = array_ops.reshape(grad, new int[] { 1 });
var new_shape = range(rank).Select(x => 1).ToArray();
grad = array_ops.reshape(grad, new_shape);
// If shape is not fully defined (but rank is), we use Shape.
if (!input_0_shape.Contains(-1)) if (!input_0_shape.Contains(-1))
input_shape = constant_op.constant(input_0_shape); input_shape = constant_op.constant(input_0_shape);
else else


+ 16
- 1
src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs View File

@@ -39,6 +39,14 @@ namespace Tensorflow
gradientFunctions[name] = func; gradientFunctions[name] = func;
} }


/// <summary>
/// Records <paramref name="name"/> in the gradient function table with a
/// null entry, so the key is present but maps to no gradient function.
/// </summary>
/// <param name="name">Key to store the null entry under.</param>
public static void RegisterNoGradientFunction(string name)
{
    // Create the table on first use.
    gradientFunctions = gradientFunctions
        ?? new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();

    gradientFunctions[name] = null;
}

public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op) public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op)
{ {
if (op.inputs == null) return null; if (op.inputs == null) return null;
@@ -68,11 +76,18 @@ namespace Tensorflow
args: new object[] { oper, out_grads }) as Tensor[] args: new object[] { oper, out_grads }) as Tensor[]
); );
} }

// REGISTER_NO_GRADIENT_OP
methods = g.GetMethods().Where(x => x.GetCustomAttribute<RegisterNoGradient>() != null)
.ToArray();

foreach (var m in methods)
RegisterNoGradientFunction(m.GetCustomAttribute<RegisterNoGradient>().Name);
} }
} }


if (!gradientFunctions.ContainsKey(op.type)) if (!gradientFunctions.ContainsKey(op.type))
throw new NotImplementedException($"can't get graident function through get_gradient_function {op.type}");
throw new LookupError($"can't get graident function through get_gradient_function {op.type}");


return gradientFunctions[op.type]; return gradientFunctions[op.type];
} }


+ 1
- 1
src/TensorFlowNET.Core/Operations/nn_impl.py.cs View File

@@ -154,7 +154,7 @@ namespace Tensorflow


public static Tensor sigmoid_cross_entropy_with_logits(Tensor labels, Tensor logits, string name = null) public static Tensor sigmoid_cross_entropy_with_logits(Tensor labels, Tensor logits, string name = null)
{ {
return tf_with(ops.name_scope(name, "", new { }), scope =>
return tf_with(ops.name_scope(name, "logistic_loss", new { logits, labels }), scope =>
{ {
name = scope; name = scope;
logits = ops.convert_to_tensor(logits, name: "logits"); logits = ops.convert_to_tensor(logits, name: "logits");


+ 4
- 3
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -19,7 +19,7 @@
Docs: https://tensorflownet.readthedocs.io</Description> Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.11.4.0</AssemblyVersion> <AssemblyVersion>0.11.4.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.10.0: <PackageReleaseNotes>Changes since v0.10.0:
1. Upgrade NumSharp to v0.20.
1. Upgrade NumSharp to v0.20.3.
2. Add DisposableObject class to manage object lifetime. 2. Add DisposableObject class to manage object lifetime.
3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables. 3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables.
4. Change tensorflow to non-static class in order to execute some initialization process. 4. Change tensorflow to non-static class in order to execute some initialization process.
@@ -28,7 +28,8 @@ Docs: https://tensorflownet.readthedocs.io</Description>
7. Add tf.image related APIs. 7. Add tf.image related APIs.
8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor. 8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor.
9. MultiThread is safe. 9. MultiThread is safe.
10. Support n-dim indexing for tensor.</PackageReleaseNotes>
10. Support n-dim indexing for tensor.
11. Add RegisterNoGradient</PackageReleaseNotes>
<LangVersion>7.3</LangVersion> <LangVersion>7.3</LangVersion>
<FileVersion>0.11.4.0</FileVersion> <FileVersion>0.11.4.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
@@ -62,7 +63,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>


<ItemGroup> <ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.5.1" /> <PackageReference Include="Google.Protobuf" Version="3.5.1" />
<PackageReference Include="NumSharp" Version="0.20.2" />
<PackageReference Include="NumSharp" Version="0.20.3" />
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>


+ 1
- 1
src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj View File

@@ -2,7 +2,7 @@
<PropertyGroup> <PropertyGroup>
<OutputType>Exe</OutputType> <OutputType>Exe</OutputType>
<TargetFramework>netcoreapp2.2</TargetFramework>
<TargetFramework>netcoreapp3.0</TargetFramework>
<NoWin32Manifest>true</NoWin32Manifest> <NoWin32Manifest>true</NoWin32Manifest>
<AssemblyName>TensorFlowBenchmark</AssemblyName> <AssemblyName>TensorFlowBenchmark</AssemblyName>
<RootNamespace>TensorFlowBenchmark</RootNamespace> <RootNamespace>TensorFlowBenchmark</RootNamespace>


+ 154
- 0
test/TensorFlowNET.Examples/NeuralNetworks/FullyConnected.cs View File

@@ -0,0 +1,154 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.Examples
{
/// <summary>
/// Example that trains a small fully connected network to predict whether the
/// sum of a random input vector is positive.
/// Based on the article "How to optimise your input pipeline with queues and multi-threading":
/// https://blog.metaflow.fr/tensorflow-how-to-optimise-your-input-pipeline-with-queues-and-multi-threading-e7c3874157e0
/// </summary>
public class FullyConnected : IExample
{
    public bool Enabled { get; set; } = true;
    public bool IsImportingGraph { get; set; }

    public string Name => "Fully Connected Neural Network";

    // Graph handles created by BuildGraph and consumed by Train.
    Tensor input = null;
    Tensor x_inputs_data = null;
    Tensor y_inputs_data = null;
    Tensor accuracy = null;
    Tensor y_true = null;
    Tensor loss_op = null;
    Operation train_op = null;

    /// <summary>
    /// Builds the data generators, the two-layer network, the loss, the
    /// accuracy metric and the Adam training op, storing the handles in the
    /// fields of this instance.
    /// </summary>
    /// <returns>The default graph fetched at the start of the method.</returns>
    public Graph BuildGraph()
    {
        var g = tf.get_default_graph();
        // batches of 128 samples, each containing 1024 data points
        x_inputs_data = tf.random_normal(new[] { 128, 1024 }, mean: 0, stddev: 1);
        // We will try to predict this law:
        // predict 1 if the sum of the elements is positive and 0 otherwise
        y_inputs_data = tf.cast(tf.reduce_sum(x_inputs_data, axis: 1, keepdims: true) > 0, tf.int32);
        Tensor z = null;

        tf_with(tf.variable_scope("placeholder"), delegate
        {
            input = tf.placeholder(tf.float32, shape: (-1, 1024));
            y_true = tf.placeholder(tf.int32, shape: (-1, 1));
        });

        tf_with(tf.variable_scope("FullyConnected"), delegate
        {
            // Hidden layer: 1024 -> 1024 followed by ReLU.
            var w = tf.get_variable("w", shape: (1024, 1024), initializer: tf.random_normal_initializer(stddev: 0.1f));
            var b = tf.get_variable("b", shape: 1024, initializer: tf.constant_initializer(0.1));
            z = tf.matmul(input, w) + b;
            var y = tf.nn.relu(z);

            // Output layer: 1024 -> 1, producing the logits used below.
            var w2 = tf.get_variable("w2", shape: (1024, 1), initializer: tf.random_normal_initializer(stddev: 0.1f));
            var b2 = tf.get_variable("b2", shape: 1, initializer: tf.constant_initializer(0.1));
            z = tf.matmul(y, w2) + b2;
        });

        tf_with(tf.variable_scope("Loss"), delegate
        {
            var losses = tf.nn.sigmoid_cross_entropy_with_logits(tf.cast(y_true, tf.float32), z);
            loss_op = tf.reduce_mean(losses);
        });

        tf_with(tf.variable_scope("Accuracy"), delegate
        {
            // A positive logit is read as class 1.
            var y_pred = tf.cast(z > 0, tf.int32);
            accuracy = tf.reduce_mean(tf.cast(tf.equal(y_pred, y_true), tf.float32));
        });

        // We add the training operation, ...
        var adam = tf.train.AdamOptimizer(0.01f);
        train_op = adam.minimize(loss_op, name: "train_op");

        return g;
    }

    public Graph ImportGraph()
    {
        throw new NotImplementedException();
    }

    public void Predict(Session sess)
    {
        throw new NotImplementedException();
    }

    public void PrepareData()
    {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Builds the graph and trains it inside a fresh session.
    /// </summary>
    /// <returns>Always true; failures surface as exceptions.</returns>
    public bool Run()
    {
        // Called for its side effect of populating the tensor fields; the
        // returned graph is not needed here.
        BuildGraph();

        using (var sess = tf.Session())
        {
            Train(sess);
        }

        return true;
    }

    public void Test(Session sess)
    {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Initialises the variables, runs 5000 training steps on freshly sampled
    /// batches and reports loss, accuracy and elapsed time through print.
    /// </summary>
    /// <param name="sess">Session in which the graph is executed.</param>
    public void Train(Session sess)
    {
        var sw = Stopwatch.StartNew();

        // init variables
        sess.run(tf.global_variables_initializer());

        // check the accuracy before training; the value is reported instead of
        // being computed and thrown away
        var (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data));
        var accuracy_before = sess.run(accuracy, (input, x_input), (y_true, y_input));
        print($"accuracy before training: {accuracy_before}");

        // training
        foreach (var i in range(5000))
        {
            // by sampling some input data (fetching)
            (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data));
            var (_, loss) = sess.run((train_op, loss_op), (input, x_input), (y_true, y_input));

            // We regularly check the loss
            if (i % 500 == 0)
                print($"iter:{i} - loss:{loss}");
        }

        // Finally, we check our final accuracy on a new batch
        (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data));
        var accuracy_after = sess.run(accuracy, (input, x_input), (y_true, y_input));
        print($"accuracy after training: {accuracy_after}");

        print($"Time taken: {sw.Elapsed.TotalSeconds}s");
    }
}
}

test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs → test/TensorFlowNET.Examples/NeuralNetworks/NeuralNetXor.cs View File


+ 1
- 1
test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj View File

@@ -2,7 +2,7 @@


<PropertyGroup> <PropertyGroup>
<OutputType>Exe</OutputType> <OutputType>Exe</OutputType>
<TargetFramework>netcoreapp2.2</TargetFramework>
<TargetFramework>netcoreapp3.0</TargetFramework>
<GeneratePackageOnBuild>false</GeneratePackageOnBuild> <GeneratePackageOnBuild>false</GeneratePackageOnBuild>
</PropertyGroup> </PropertyGroup>




+ 1
- 1
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -2,7 +2,7 @@


<PropertyGroup> <PropertyGroup>
<OutputType>Exe</OutputType> <OutputType>Exe</OutputType>
<TargetFramework>netcoreapp2.2</TargetFramework>
<TargetFramework>netcoreapp3.0</TargetFramework>
<GeneratePackageOnBuild>false</GeneratePackageOnBuild> <GeneratePackageOnBuild>false</GeneratePackageOnBuild>
</PropertyGroup> </PropertyGroup>




+ 1
- 1
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk"> <Project Sdk="Microsoft.NET.Sdk">


<PropertyGroup> <PropertyGroup>
<TargetFramework>netcoreapp2.2</TargetFramework>
<TargetFramework>netcoreapp3.0</TargetFramework>


<IsPackable>false</IsPackable> <IsPackable>false</IsPackable>




Loading…
Cancel
Save