diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
index f2749abc..956c52be 100644
--- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
@@ -51,6 +51,9 @@ namespace Tensorflow
public Tensor inv(Tensor input, bool adjoint = false, string name = null)
=> ops.matrix_inverse(input, adjoint: adjoint, name: name);
+ public Tensor global_norm(Tensor[] t_list, string name = null)
+ => clip_ops.global_norm(t_list, name: name);
+
public Tensor lstsq(Tensor matrix, Tensor rhs,
NDArray l2_regularizer = null, bool fast = true, string name = null)
=> ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name);
diff --git a/src/TensorFlowNET.Core/Operations/clip_ops.cs b/src/TensorFlowNET.Core/Operations/clip_ops.cs
index b782c469..7b48c9e5 100644
--- a/src/TensorFlowNET.Core/Operations/clip_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/clip_ops.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using System.Linq;
using static Tensorflow.Binding;
namespace Tensorflow
@@ -36,5 +37,24 @@ namespace Tensorflow
return t_max;
});
}
+
+ ///
+ /// Computes the global norm of multiple tensors.
+ ///
+ ///
+ ///
+ ///
+ public static Tensor global_norm(Tensor[] t_list, string name = null)
+ {
+ return tf_with(ops.name_scope(name, "global_norm", t_list), delegate
+ {
+ var half_squared_norms = t_list.Select(v => nn_ops.l2_loss(v)).ToArray();
+ var half_squared_norm = math_ops.reduce_sum(array_ops.stack(half_squared_norms));
+ var norm = math_ops.sqrt(half_squared_norm *
+ constant_op.constant(2.0, dtype: half_squared_norm.dtype),
+ name: "global_norm");
+ return norm;
+ });
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
index 790221f8..5eb05eaa 100644
--- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
+++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
@@ -75,11 +75,19 @@ namespace Tensorflow.Keras.Engine
metric_obj = keras.metrics.sparse_categorical_accuracy;
else
metric_obj = keras.metrics.categorical_accuracy;
+
+ metric = "accuracy";
}
else if(metric == "mean_absolute_error" || metric == "mae")
+ {
metric_obj = keras.metrics.mean_absolute_error;
+ metric = "mean_absolute_error";
+ }
else if (metric == "mean_absolute_percentage_error" || metric == "mape")
+ {
metric_obj = keras.metrics.mean_absolute_percentage_error;
+ metric = "mean_absolute_percentage_error";
+ }
else
throw new NotImplementedException("");
diff --git a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
index d724b333..c422bfa6 100644
--- a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
@@ -4,7 +4,6 @@ namespace Tensorflow.Keras.Metrics
{
public class MeanMetricWrapper : Mean
{
- string name;
Func _fn = null;
public MeanMetricWrapper(Func fn, string name, TF_DataType dtype = TF_DataType.TF_FLOAT)
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs
index eefc1c47..f7fb965b 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs
@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow;
using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.ManagedAPI
@@ -54,5 +55,13 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
var e = tf.linalg.einsum("ij,jk->ik", (m0, m1));
Assert.AreEqual(e.shape, (2, 5));
}
+
+ [TestMethod]
+ public void GlobalNorm()
+ {
+ var t_list = new Tensors(tf.constant(new float[] { 1, 2, 3, 4 }), tf.constant(new float[] { 5, 6, 7, 8 }));
+ var norm = tf.linalg.global_norm(t_list);
+ Assert.AreEqual(norm.numpy(), 14.282857f);
+ }
}
}