From 75ae2e9e09d74683955478f0a6fe41eb8593a155 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 28 Sep 2019 09:55:44 -0500 Subject: [PATCH] RegisterNoGradient, LookupError --- src/TensorFlowNET.Core/APIs/tf.math.cs | 10 +- .../Exceptions/LookupError.cs | 17 ++ .../Gradients/RegisterNoGradient.cs | 33 ++++ .../Gradients/gradients_util.cs | 43 ++++- src/TensorFlowNET.Core/Gradients/math_grad.cs | 15 +- .../ops.gradient_function_mapping.cs | 17 +- .../Operations/nn_impl.py.cs | 2 +- .../TensorFlowNET.Core.csproj | 7 +- .../TensorFlowBenchmark.csproj | 2 +- .../NeuralNetworks/FullyConnected.cs | 154 ++++++++++++++++++ .../NeuralNetXor.cs | 0 .../TensorFlowNET.Examples.GPU.csproj | 2 +- .../TensorFlowNET.Examples.csproj | 2 +- .../TensorFlowNET.UnitTest.csproj | 2 +- 14 files changed, 279 insertions(+), 27 deletions(-) create mode 100644 src/TensorFlowNET.Core/Exceptions/LookupError.cs create mode 100644 src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs create mode 100644 test/TensorFlowNET.Examples/NeuralNetworks/FullyConnected.cs rename test/TensorFlowNET.Examples/{BasicModels => NeuralNetworks}/NeuralNetXor.cs (100%) diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index e62f6358..cba05d0e 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -431,17 +431,19 @@ namespace Tensorflow /// /// /// - public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null) + public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null, + bool keepdims = false, string name = null) { if(!axis.HasValue && reduction_indices.HasValue) return math_ops.reduce_sum(input, reduction_indices.Value); else if (axis.HasValue && !reduction_indices.HasValue) return math_ops.reduce_sum(input, axis.Value); - return math_ops.reduce_sum(input); + return math_ops.reduce_sum(input, keepdims: keepdims, name: name); } - public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null) - => math_ops.reduce_sum(input, axis); + public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null, + bool keepdims = false, string name = null) + => math_ops.reduce_sum(input, axis, keepdims: keepdims, name: name); /// /// Computes the maximum of elements across dimensions of a tensor. diff --git a/src/TensorFlowNET.Core/Exceptions/LookupError.cs b/src/TensorFlowNET.Core/Exceptions/LookupError.cs new file mode 100644 index 00000000..ebbaa526 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/LookupError.cs @@ -0,0 +1,17 @@ +using System; + +namespace Tensorflow +{ + public class LookupError : TensorflowException + { + public LookupError() : base() + { + + } + + public LookupError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs b/src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs new file mode 100644 index 00000000..d573e317 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs @@ -0,0 +1,33 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow.Gradients +{ + /// + /// REGISTER_NO_GRADIENT_OP(""); + /// + public class RegisterNoGradient : Attribute + { + public string Name { get; set; } + + public RegisterNoGradient(string name) + { + Name = name; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 5aa0d044..bfa1d296 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -117,19 +117,44 @@ namespace Tensorflow Tensor[] in_grads = null; var is_partitioned_call = _IsPartitionedCall(op); var is_func_call = false; - var has_out_grads = true; + var has_out_grads = out_grads.Exists(x => x != null); if (has_out_grads && !stop_ops.Contains(op)) { // A grad_fn must be defined, either as a function or as None // for ops that do not have gradients. - var grad_fn = ops.get_gradient_function(op); - if (is_func_call) + Func grad_fn = null; + try { + grad_fn = ops.get_gradient_function(op); + } + catch (LookupError) + { + if (is_func_call) + { + if (is_partitioned_call) + { + } + else + { + + } + } + else + { + throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})"); + } } - else + + // if (loop_state) + //loop_state.EnterGradWhileContext(op, before: false); + + if ((is_func_call || grad_fn != null) && has_out_grads) { + // NOTE: If _AggregatedGrads didn't compute a value for the i'th + // output, it means that the cost does not depend on output[i], + // therefore dC/doutput[i] is 0. foreach (var (i, out_grad) in enumerate(out_grads)) { if (out_grad == null) @@ -143,13 +168,11 @@ namespace Tensorflow tf_with(ops.name_scope(op.name + "_grad"), scope1 => { - string name1 = scope1; if (grad_fn != null) { in_grads = _MaybeCompile(grad_scope, op, out_grads.Select(x => x[0]).ToArray(), null, grad_fn); - _VerifyGeneratedGradients(in_grads, op); } - + _VerifyGeneratedGradients(in_grads, op); if (gate_gradients && in_grads.Count(x => x != null) > 1) { ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true); @@ -157,6 +180,12 @@ namespace Tensorflow } }); } + else + { + // If no grad_fn is defined or none of out_grads is available, + // just propagate a list of None backwards. + in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; + } } else { diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 85f849cc..0e9e09e5 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -96,12 +96,11 @@ namespace Tensorflow.Gradients }); } - [RegisterGradient("GreaterEqual")] - public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) - { - var grad = grads[0]; - throw new NotImplementedException("_GreaterEqualGrad"); - } + [RegisterNoGradient("GreaterEqual")] + public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null; + + [RegisterNoGradient("ZerosLike")] + public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null; [RegisterGradient("Identity")] public static Tensor[] _IdGrad(Operation op, Tensor[] grads) @@ -415,7 +414,9 @@ namespace Tensorflow.Gradients var rank = input_0_shape.Length; if (Enumerable.SequenceEqual(Enumerable.Range(0, rank), axes.Data())) { - grad = array_ops.reshape(grad, new int[] { 1 }); + var new_shape = range(rank).Select(x => 1).ToArray(); + grad = array_ops.reshape(grad, new_shape); + // If shape is not fully defined (but rank is), we use Shape. if (!input_0_shape.Contains(-1)) input_shape = constant_op.constant(input_0_shape); else diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs index 0b624ba1..4891fcbb 100644 --- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs +++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs @@ -39,6 +39,14 @@ namespace Tensorflow gradientFunctions[name] = func; } + public static void RegisterNoGradientFunction(string name) + { + if (gradientFunctions == null) + gradientFunctions = new Dictionary>(); + + gradientFunctions[name] = null; + } + public static Func get_gradient_function(Operation op) { if (op.inputs == null) return null; @@ -68,11 +76,18 @@ namespace Tensorflow args: new object[] { oper, out_grads }) as Tensor[] ); } + + // REGISTER_NO_GRADIENT_OP + methods = g.GetMethods().Where(x => x.GetCustomAttribute() != null) + .ToArray(); + + foreach (var m in methods) + RegisterNoGradientFunction(m.GetCustomAttribute().Name); } } if (!gradientFunctions.ContainsKey(op.type)) - throw new NotImplementedException($"can't get graident function through get_gradient_function {op.type}"); + throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); return gradientFunctions[op.type]; } diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs index 239454b6..bced0047 100644 --- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs @@ -154,7 +154,7 @@ namespace Tensorflow public static Tensor sigmoid_cross_entropy_with_logits(Tensor labels, Tensor logits, string name = null) { - return tf_with(ops.name_scope(name, "", new { }), scope => + return tf_with(ops.name_scope(name, "logistic_loss", new { logits, labels }), scope => { name = scope; logits = ops.convert_to_tensor(logits, name: "logits"); diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 3594649d..13f56a60 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -19,7 +19,7 @@ Docs: https://tensorflownet.readthedocs.io 0.11.4.0 Changes since v0.10.0: -1. Upgrade NumSharp to v0.20. +1. Upgrade NumSharp to v0.20.3. 2. Add DisposableObject class to manage object lifetime. 3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables. 4. Change tensorflow to non-static class in order to execute some initialization process. @@ -28,7 +28,8 @@ Docs: https://tensorflownet.readthedocs.io 7. Add tf.image related APIs. 8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor. 9. MultiThread is safe. -10. Support n-dim indexing for tensor. +10. Support n-dim indexing for tensor. +11. Add RegisterNoGradient 7.3 0.11.4.0 LICENSE @@ -62,7 +63,7 @@ Docs: https://tensorflownet.readthedocs.io - + diff --git a/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj b/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj index 4618f06b..b0e91991 100644 --- a/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj +++ b/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj @@ -2,7 +2,7 @@ Exe - netcoreapp2.2 + netcoreapp3.0 true TensorFlowBenchmark TensorFlowBenchmark diff --git a/test/TensorFlowNET.Examples/NeuralNetworks/FullyConnected.cs b/test/TensorFlowNET.Examples/NeuralNetworks/FullyConnected.cs new file mode 100644 index 00000000..faec454c --- /dev/null +++ b/test/TensorFlowNET.Examples/NeuralNetworks/FullyConnected.cs @@ -0,0 +1,154 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using NumSharp; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.Examples +{ + /// + /// How to optimise your input pipeline with queues and multi-threading + /// https://blog.metaflow.fr/tensorflow-how-to-optimise-your-input-pipeline-with-queues-and-multi-threading-e7c3874157e0 + /// + public class FullyConnected : IExample + { + public bool Enabled { get; set; } = true; + public bool IsImportingGraph { get; set; } + + public string Name => "Fully Connected Neural Network"; + + Tensor input = null; + Tensor x_inputs_data = null; + Tensor y_inputs_data = null; + Tensor accuracy = null; + Tensor y_true = null; + Tensor loss_op = null; + Operation train_op = null; + + public Graph BuildGraph() + { + var g = tf.get_default_graph(); + // batches of 128 samples, each containing 1024 data points + x_inputs_data = tf.random_normal(new[] { 128, 1024 }, mean: 0, stddev: 1); + // We will try to predict this law: + // predict 1 if the sum of the elements is positive and 0 otherwise + y_inputs_data = tf.cast(tf.reduce_sum(x_inputs_data, axis: 1, keepdims: true) > 0, tf.int32); + + Tensor z = null; + + + tf_with(tf.variable_scope("placeholder"), delegate + { + input = tf.placeholder(tf.float32, shape: (-1, 1024)); + y_true = tf.placeholder(tf.int32, shape: (-1, 1)); + }); + + tf_with(tf.variable_scope("FullyConnected"), delegate + { + var w = tf.get_variable("w", shape: (1024, 1024), initializer: tf.random_normal_initializer(stddev: 0.1f)); + var b = tf.get_variable("b", shape: 1024, initializer: tf.constant_initializer(0.1)); + z = tf.matmul(input, w) + b; + var y = tf.nn.relu(z); + + var w2 = tf.get_variable("w2", shape: (1024, 1), initializer: tf.random_normal_initializer(stddev: 0.1f)); + var b2 = tf.get_variable("b2", shape: 1, initializer: tf.constant_initializer(0.1)); + z = tf.matmul(y, w2) + b2; + }); + + tf_with(tf.variable_scope("Loss"), delegate + { + var losses = tf.nn.sigmoid_cross_entropy_with_logits(tf.cast(y_true, tf.float32), z); + loss_op = tf.reduce_mean(losses); + }); + + tf_with(tf.variable_scope("Accuracy"), delegate + { + var y_pred = tf.cast(z > 0, tf.int32); + accuracy = tf.reduce_mean(tf.cast(tf.equal(y_pred, y_true), tf.float32)); + // accuracy = tf.Print(accuracy, data =[accuracy], message = "accuracy:") + }); + + // We add the training operation, ... + var adam = tf.train.AdamOptimizer(0.01f); + train_op = adam.minimize(loss_op, name: "train_op"); + + return g; + } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public void Predict(Session sess) + { + throw new NotImplementedException(); + } + + public void PrepareData() + { + throw new NotImplementedException(); + } + + public bool Run() + { + var g = BuildGraph(); + using (var sess = tf.Session()) + Train(sess); + return true; + } + + public void Test(Session sess) + { + throw new NotImplementedException(); + } + + public void Train(Session sess) + { + var sw = new Stopwatch(); + sw.Start(); + // init variables + sess.run(tf.global_variables_initializer()); + + // check the accuracy before training + var (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data)); + sess.run(accuracy, (input, x_input), (y_true, y_input)); + + // training + foreach (var i in range(5000)) + { + // by sampling some input data (fetching) + (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data)); + var (_, loss) = sess.run((train_op, loss_op), (input, x_input), (y_true, y_input)); + + // We regularly check the loss + if (i % 500 == 0) + print($"iter:{i} - loss:{loss}"); + } + + // Finally, we check our final accuracy + (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data)); + sess.run(accuracy, (input, x_input), (y_true, y_input)); + + print($"Time taken: {sw.Elapsed.TotalSeconds}s"); + } + } +} diff --git a/test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs b/test/TensorFlowNET.Examples/NeuralNetworks/NeuralNetXor.cs similarity index 100% rename from test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs rename to test/TensorFlowNET.Examples/NeuralNetworks/NeuralNetXor.cs diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj index 2f248b5d..43b5a699 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj @@ -2,7 +2,7 @@ Exe - netcoreapp2.2 + netcoreapp3.0 false diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index f4c90a7c..39cc1eeb 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -2,7 +2,7 @@ Exe - netcoreapp2.2 + netcoreapp3.0 false diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index d2ea6ebf..71fb3abb 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -1,7 +1,7 @@  - netcoreapp2.2 + netcoreapp3.0 false