diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs index f2749abc..956c52be 100644 --- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs +++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs @@ -51,6 +51,9 @@ namespace Tensorflow public Tensor inv(Tensor input, bool adjoint = false, string name = null) => ops.matrix_inverse(input, adjoint: adjoint, name: name); + public Tensor global_norm(Tensor[] t_list, string name = null) + => clip_ops.global_norm(t_list, name: name); + public Tensor lstsq(Tensor matrix, Tensor rhs, NDArray l2_regularizer = null, bool fast = true, string name = null) => ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name); diff --git a/src/TensorFlowNET.Core/Operations/clip_ops.cs b/src/TensorFlowNET.Core/Operations/clip_ops.cs index b782c469..7b48c9e5 100644 --- a/src/TensorFlowNET.Core/Operations/clip_ops.cs +++ b/src/TensorFlowNET.Core/Operations/clip_ops.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System.Linq; using static Tensorflow.Binding; namespace Tensorflow @@ -36,5 +37,24 @@ namespace Tensorflow return t_max; }); } + + /// + /// Computes the global norm of multiple tensors. + /// + /// + /// + /// + public static Tensor global_norm(Tensor[] t_list, string name = null) + { + return tf_with(ops.name_scope(name, "global_norm", t_list), delegate + { + var half_squared_norms = t_list.Select(v => nn_ops.l2_loss(v)).ToArray(); + var half_squared_norm = math_ops.reduce_sum(array_ops.stack(half_squared_norms)); + var norm = math_ops.sqrt(half_squared_norm * + constant_op.constant(2.0, dtype: half_squared_norm.dtype), + name: "global_norm"); + return norm; + }); + } } } diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs index 790221f8..5eb05eaa 100644 --- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs +++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs @@ -75,11 +75,19 @@ namespace Tensorflow.Keras.Engine metric_obj = keras.metrics.sparse_categorical_accuracy; else metric_obj = keras.metrics.categorical_accuracy; + + metric = "accuracy"; } else if(metric == "mean_absolute_error" || metric == "mae") + { metric_obj = keras.metrics.mean_absolute_error; + metric = "mean_absolute_error"; + } else if (metric == "mean_absolute_percentage_error" || metric == "mape") + { metric_obj = keras.metrics.mean_absolute_percentage_error; + metric = "mean_absolute_percentage_error"; + } else throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs index d724b333..c422bfa6 100644 --- a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs +++ b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs @@ -4,7 +4,6 @@ namespace Tensorflow.Keras.Metrics { public class MeanMetricWrapper : Mean { - string name; Func _fn = null; public MeanMetricWrapper(Func fn, string name, TF_DataType dtype = TF_DataType.TF_FLOAT) diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs index eefc1c47..f7fb965b 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.ManagedAPI @@ -54,5 +55,13 @@ namespace TensorFlowNET.UnitTest.ManagedAPI var e = tf.linalg.einsum("ij,jk->ik", (m0, m1)); Assert.AreEqual(e.shape, (2, 5)); } + + [TestMethod] + public void GlobalNorm() + { + var t_list = new Tensors(tf.constant(new float[] { 1, 2, 3, 4 }), tf.constant(new float[] { 5, 6, 7, 8 })); + var norm = tf.linalg.global_norm(t_list); + Assert.AreEqual(norm.numpy(), 14.282857f); + } } }