diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index e62f6358..cba05d0e 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -431,17 +431,19 @@ namespace Tensorflow
///
///
///
-        public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null)
+        public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null,
+            bool keepdims = false, string name = null)
         {
-            if(!axis.HasValue && reduction_indices.HasValue)
-                return math_ops.reduce_sum(input, reduction_indices.Value);
-            else if (axis.HasValue && !reduction_indices.HasValue)
-                return math_ops.reduce_sum(input, axis.Value);
-            return math_ops.reduce_sum(input);
+            // When exactly one of axis / reduction_indices is given, it selects the dimension to reduce.
+            // keepdims and name are forwarded on every path, so reduce_sum(x, axis: 1, keepdims: true) honours keepdims.
+            if (axis.HasValue != reduction_indices.HasValue)
+                return math_ops.reduce_sum(input, new[] { axis ?? reduction_indices.Value }, keepdims: keepdims, name: name);
+            return math_ops.reduce_sum(input, keepdims: keepdims, name: name); // neither or both given: reduce over all dimensions
         }
-        public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null)
-            => math_ops.reduce_sum(input, axis);
+        public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null,
+            bool keepdims = false, string name = null) // reduction_indices is not read by this overload
+            => math_ops.reduce_sum(input, axis, keepdims: keepdims, name: name);
///
/// Computes the maximum of elements across dimensions of a tensor.
diff --git a/src/TensorFlowNET.Core/Exceptions/LookupError.cs b/src/TensorFlowNET.Core/Exceptions/LookupError.cs
new file mode 100644
index 00000000..ebbaa526
--- /dev/null
+++ b/src/TensorFlowNET.Core/Exceptions/LookupError.cs
@@ -0,0 +1,17 @@
+using System;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// Thrown when a lookup fails, e.g. when no gradient function is registered for an op type.
+    /// </summary>
+    public class LookupError : TensorflowException
+    {
+        /// <summary>Creates the error without a message.</summary>
+        public LookupError() { }
+
+        /// <summary>Creates the error with a description of the failed lookup.</summary>
+        /// <param name="message">Text describing what could not be found.</param>
+        public LookupError(string message) : base(message) { }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs b/src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs
new file mode 100644
index 00000000..d573e317
--- /dev/null
+++ b/src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs
@@ -0,0 +1,33 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+
+namespace Tensorflow.Gradients
+{
+    /// <summary>
+    /// Marks a method as declaring an op type that has no gradient (REGISTER_NO_GRADIENT_OP("")).
+    /// </summary>
+    [AttributeUsage(AttributeTargets.Method)]
+    public class RegisterNoGradient : Attribute
+    {
+        /// <summary>Op type name that is registered without a gradient function.</summary>
+        public string Name { get; set; }
+
+        /// <param name="name">Op type name, e.g. "GreaterEqual".</param>
+        public RegisterNoGradient(string name) => Name = name;
+    }
+}
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
index 5aa0d044..bfa1d296 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
@@ -117,19 +117,44 @@ namespace Tensorflow
Tensor[] in_grads = null;
var is_partitioned_call = _IsPartitionedCall(op);
var is_func_call = false;
- var has_out_grads = true;
+ var has_out_grads = out_grads.Exists(x => x != null);
if (has_out_grads && !stop_ops.Contains(op))
{
// A grad_fn must be defined, either as a function or as None
// for ops that do not have gradients.
- var grad_fn = ops.get_gradient_function(op);
- if (is_func_call)
+ Func grad_fn = null;
+ try
{
+ grad_fn = ops.get_gradient_function(op);
+ }
+ catch (LookupError)
+ {
+ if (is_func_call)
+ {
+ if (is_partitioned_call)
+ {
+ }
+ else
+ {
+
+ }
+ }
+ else
+ {
+ throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
+ }
}
- else
+
+ // if (loop_state)
+ //loop_state.EnterGradWhileContext(op, before: false);
+
+ if ((is_func_call || grad_fn != null) && has_out_grads)
{
+ // NOTE: If _AggregatedGrads didn't compute a value for the i'th
+ // output, it means that the cost does not depend on output[i],
+ // therefore dC/doutput[i] is 0.
foreach (var (i, out_grad) in enumerate(out_grads))
{
if (out_grad == null)
@@ -143,13 +168,11 @@ namespace Tensorflow
tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
{
- string name1 = scope1;
if (grad_fn != null)
{
in_grads = _MaybeCompile(grad_scope, op, out_grads.Select(x => x[0]).ToArray(), null, grad_fn);
- _VerifyGeneratedGradients(in_grads, op);
}
-
+ _VerifyGeneratedGradients(in_grads, op);
if (gate_gradients && in_grads.Count(x => x != null) > 1)
{
ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
@@ -157,6 +180,12 @@ namespace Tensorflow
}
});
}
+ else
+ {
+ // If no grad_fn is defined or none of out_grads is available,
+ // just propagate a list of None backwards.
+ in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
+ }
}
else
{
diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs
index 85f849cc..0e9e09e5 100644
--- a/src/TensorFlowNET.Core/Gradients/math_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs
@@ -96,12 +96,11 @@ namespace Tensorflow.Gradients
});
}
- [RegisterGradient("GreaterEqual")]
- public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads)
- {
- var grad = grads[0];
- throw new NotImplementedException("_GreaterEqualGrad");
- }
+        [RegisterNoGradient("GreaterEqual")] // op name is registered with a null gradient function
+        public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null;
+
+        [RegisterNoGradient("ZerosLike")] // likewise marked as having no gradient
+        public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null;
[RegisterGradient("Identity")]
public static Tensor[] _IdGrad(Operation op, Tensor[] grads)
@@ -415,7 +414,9 @@ namespace Tensorflow.Gradients
var rank = input_0_shape.Length;
if (Enumerable.SequenceEqual(Enumerable.Range(0, rank), axes.Data()))
{
- grad = array_ops.reshape(grad, new int[] { 1 });
+ var new_shape = range(rank).Select(x => 1).ToArray();
+ grad = array_ops.reshape(grad, new_shape);
+ // If shape is not fully defined (but rank is), we use Shape.
if (!input_0_shape.Contains(-1))
input_shape = constant_op.constant(input_0_shape);
else
diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
index 0b624ba1..4891fcbb 100644
--- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
+++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
@@ -39,6 +39,14 @@ namespace Tensorflow
gradientFunctions[name] = func;
}
+        /// <summary>Registers <paramref name="name"/> as an op type without a gradient by storing a null function for it.</summary>
+        public static void RegisterNoGradientFunction(string name)
+        {
+            if (gradientFunctions == null)
+                gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();
+            gradientFunctions[name] = null; // null entry: the op is known but deliberately has no gradient
+        }
+
public static Func get_gradient_function(Operation op)
{
if (op.inputs == null) return null;
@@ -68,11 +76,18 @@ namespace Tensorflow
args: new object[] { oper, out_grads }) as Tensor[]
);
}
+
+ // REGISTER_NO_GRADIENT_OP
+ methods = g.GetMethods().Where(x => x.GetCustomAttribute() != null)
+ .ToArray();
+
+ foreach (var m in methods)
+ RegisterNoGradientFunction(m.GetCustomAttribute().Name);
}
}
if (!gradientFunctions.ContainsKey(op.type))
- throw new NotImplementedException($"can't get graident function through get_gradient_function {op.type}");
+ throw new LookupError($"can't get graident function through get_gradient_function {op.type}");
return gradientFunctions[op.type];
}
diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
index 239454b6..bced0047 100644
--- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
@@ -154,7 +154,7 @@ namespace Tensorflow
public static Tensor sigmoid_cross_entropy_with_logits(Tensor labels, Tensor logits, string name = null)
{
- return tf_with(ops.name_scope(name, "", new { }), scope =>
+ return tf_with(ops.name_scope(name, "logistic_loss", new { logits, labels }), scope =>
{
name = scope;
logits = ops.convert_to_tensor(logits, name: "logits");
diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index 3594649d..13f56a60 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -19,7 +19,7 @@
Docs: https://tensorflownet.readthedocs.io
0.11.4.0
Changes since v0.10.0:
-1. Upgrade NumSharp to v0.20.
+1. Upgrade NumSharp to v0.20.3.
2. Add DisposableObject class to manage object lifetime.
3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables.
4. Change tensorflow to non-static class in order to execute some initialization process.
@@ -28,7 +28,8 @@ Docs: https://tensorflownet.readthedocs.io
7. Add tf.image related APIs.
8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor.
9. MultiThread is safe.
-10. Support n-dim indexing for tensor.
+10. Support n-dim indexing for tensor.
+11. Add RegisterNoGradient attribute for ops without gradients.
7.3
0.11.4.0
LICENSE
@@ -62,7 +63,7 @@ Docs: https://tensorflownet.readthedocs.io
-
+
diff --git a/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj b/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj
index 4618f06b..b0e91991 100644
--- a/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj
+++ b/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj
@@ -2,7 +2,7 @@
Exe
- netcoreapp2.2
+ netcoreapp3.0
true
TensorFlowBenchmark
TensorFlowBenchmark
diff --git a/test/TensorFlowNET.Examples/NeuralNetworks/FullyConnected.cs b/test/TensorFlowNET.Examples/NeuralNetworks/FullyConnected.cs
new file mode 100644
index 00000000..faec454c
--- /dev/null
+++ b/test/TensorFlowNET.Examples/NeuralNetworks/FullyConnected.cs
@@ -0,0 +1,154 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using NumSharp;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text;
+using Tensorflow;
+using static Tensorflow.Binding;
+
+namespace TensorFlowNET.Examples
+{
+    /// <summary>
+    /// Trains a small fully connected network on randomly generated data to predict
+    /// whether the sum of each sample's elements is positive.
+    /// Reference: How to optimise your input pipeline with queues and multi-threading
+    /// https://blog.metaflow.fr/tensorflow-how-to-optimise-your-input-pipeline-with-queues-and-multi-threading-e7c3874157e0
+    /// </summary>
+    public class FullyConnected : IExample
+    {
+        public bool Enabled { get; set; } = true;
+        public bool IsImportingGraph { get; set; }
+
+        public string Name => "Fully Connected Neural Network";
+
+        // Graph nodes created by BuildGraph and consumed by Train.
+        Tensor input = null;
+        Tensor x_inputs_data = null;
+        Tensor y_inputs_data = null;
+        Tensor accuracy = null;
+        Tensor y_true = null;
+        Tensor loss_op = null;
+        Operation train_op = null;
+
+        /// <summary>Adds the data generators, placeholders, network, loss, accuracy and training op to the default graph.</summary>
+        /// <returns>The default graph the ops were added to.</returns>
+        public Graph BuildGraph()
+        {
+            var g = tf.get_default_graph();
+            // batches of 128 samples, each containing 1024 data points
+            x_inputs_data = tf.random_normal(new[] { 128, 1024 }, mean: 0, stddev: 1);
+            // We will try to predict this law:
+            // predict 1 if the sum of the elements is positive and 0 otherwise
+            y_inputs_data = tf.cast(tf.reduce_sum(x_inputs_data, axis: 1, keepdims: true) > 0, tf.int32);
+
+            Tensor z = null;
+
+            tf_with(tf.variable_scope("placeholder"), delegate
+            {
+                input = tf.placeholder(tf.float32, shape: (-1, 1024));
+                y_true = tf.placeholder(tf.int32, shape: (-1, 1));
+            });
+
+            tf_with(tf.variable_scope("FullyConnected"), delegate
+            {
+                var w = tf.get_variable("w", shape: (1024, 1024), initializer: tf.random_normal_initializer(stddev: 0.1f));
+                var b = tf.get_variable("b", shape: 1024, initializer: tf.constant_initializer(0.1));
+                z = tf.matmul(input, w) + b;
+                var y = tf.nn.relu(z);
+
+                var w2 = tf.get_variable("w2", shape: (1024, 1), initializer: tf.random_normal_initializer(stddev: 0.1f));
+                var b2 = tf.get_variable("b2", shape: 1, initializer: tf.constant_initializer(0.1));
+                z = tf.matmul(y, w2) + b2;
+            });
+
+            tf_with(tf.variable_scope("Loss"), delegate
+            {
+                var losses = tf.nn.sigmoid_cross_entropy_with_logits(tf.cast(y_true, tf.float32), z);
+                loss_op = tf.reduce_mean(losses);
+            });
+
+            tf_with(tf.variable_scope("Accuracy"), delegate
+            {
+                var y_pred = tf.cast(z > 0, tf.int32);
+                accuracy = tf.reduce_mean(tf.cast(tf.equal(y_pred, y_true), tf.float32));
+            });
+
+            // We add the training operation, ...
+            var adam = tf.train.AdamOptimizer(0.01f);
+            train_op = adam.minimize(loss_op, name: "train_op");
+
+            return g;
+        }
+
+        /// <summary>Not implemented for this example.</summary>
+        public Graph ImportGraph() => throw new NotImplementedException();
+
+        /// <summary>Not implemented for this example.</summary>
+        public void Predict(Session sess) => throw new NotImplementedException();
+
+        /// <summary>Not implemented; training batches are sampled from graph ops created in BuildGraph.</summary>
+        public void PrepareData() => throw new NotImplementedException();
+
+        /// <summary>Builds the graph and trains it in a new session.</summary>
+        /// <returns>Always true; failures surface as exceptions.</returns>
+        public bool Run()
+        {
+            BuildGraph();
+            using (var sess = tf.Session())
+                Train(sess);
+            return true;
+        }
+
+        /// <summary>Not implemented for this example.</summary>
+        public void Test(Session sess) => throw new NotImplementedException();
+
+        /// <summary>Runs 5000 training steps, printing the loss every 500 steps and the accuracy before and after training.</summary>
+        /// <param name="sess">Session in which the ops created by BuildGraph are run.</param>
+        public void Train(Session sess)
+        {
+            var sw = Stopwatch.StartNew();
+            // init variables
+            sess.run(tf.global_variables_initializer());
+
+            // check the accuracy before training (report it instead of discarding the result)
+            var (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data));
+            var acc = sess.run(accuracy, (input, x_input), (y_true, y_input));
+            print($"accuracy before training: {acc}");
+
+            // training
+            foreach (var i in range(5000))
+            {
+                // by sampling some input data (fetching)
+                (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data));
+                var (_, loss) = sess.run((train_op, loss_op), (input, x_input), (y_true, y_input));
+
+                // We regularly check the loss
+                if (i % 500 == 0)
+                    print($"iter:{i} - loss:{loss}");
+            }
+
+            // Finally, we check our final accuracy on a fresh batch
+            (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data));
+            acc = sess.run(accuracy, (input, x_input), (y_true, y_input));
+            print($"accuracy after training: {acc}");
+
+            print($"Time taken: {sw.Elapsed.TotalSeconds}s");
+        }
+    }
+}
diff --git a/test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs b/test/TensorFlowNET.Examples/NeuralNetworks/NeuralNetXor.cs
similarity index 100%
rename from test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs
rename to test/TensorFlowNET.Examples/NeuralNetworks/NeuralNetXor.cs
diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj
index 2f248b5d..43b5a699 100644
--- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj
+++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj
@@ -2,7 +2,7 @@
Exe
- netcoreapp2.2
+ netcoreapp3.0
false
diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
index f4c90a7c..39cc1eeb 100644
--- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
+++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
@@ -2,7 +2,7 @@
Exe
- netcoreapp2.2
+ netcoreapp3.0
false
diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
index d2ea6ebf..71fb3abb 100644
--- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
+++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
@@ -1,7 +1,7 @@
- netcoreapp2.2
+ netcoreapp3.0
false