diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs
new file mode 100644
index 00000000..1867d637
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs
@@ -0,0 +1,17 @@
+namespace Tensorflow.Keras.Metrics;
+
+public interface IMetricFunc
+{
+ ///
+ /// Accumulates metric statistics.
+ ///
+ ///
+ ///
+ ///
+ ///
+ Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null);
+
+ Tensor result();
+
+ void reset_states();
+}
diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
index 2fe6d809..511b0ef1 100644
--- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
@@ -26,4 +26,13 @@ public interface IMetricsApi
///
///
Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5);
+
+ ///
+ /// Computes how often targets are in the top K predictions.
+ ///
+ ///
+ ///
+ ///
+ ///
+ IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT);
}
diff --git a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
index c422bfa6..2e985b88 100644
--- a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
@@ -1,4 +1,5 @@
using System;
+using Tensorflow.Keras.Utils;
namespace Tensorflow.Keras.Metrics
{
@@ -17,6 +18,8 @@ namespace Tensorflow.Keras.Metrics
y_true = math_ops.cast(y_true, _dtype);
y_pred = math_ops.cast(y_pred, _dtype);
+ (y_pred, y_true) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true);
+
var matches = _fn(y_true, y_pred);
return update_state(matches, sample_weight: sample_weight);
}
diff --git a/src/TensorFlowNET.Keras/Metrics/Metric.cs b/src/TensorFlowNET.Keras/Metrics/Metric.cs
index 21457f15..1dfc39c4 100644
--- a/src/TensorFlowNET.Keras/Metrics/Metric.cs
+++ b/src/TensorFlowNET.Keras/Metrics/Metric.cs
@@ -9,7 +9,7 @@ namespace Tensorflow.Keras.Metrics
///
/// Encapsulates metric logic and state.
///
- public class Metric : Layer
+ public class Metric : Layer, IMetricFunc
{
protected IVariableV1 total;
protected IVariableV1 count;
diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
index 6b0e2d8a..dfccfdbb 100644
--- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
@@ -1,6 +1,4 @@
-using static Tensorflow.KerasApi;
-
-namespace Tensorflow.Keras.Metrics
+namespace Tensorflow.Keras.Metrics
{
public class MetricsApi : IMetricsApi
{
@@ -60,5 +58,8 @@ namespace Tensorflow.Keras.Metrics
tf.math.argmax(y_true, axis: -1), y_pred, k
);
}
+
+ public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
+ => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);
}
}
diff --git a/src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs
new file mode 100644
index 00000000..63e94102
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs
@@ -0,0 +1,12 @@
+namespace Tensorflow.Keras.Metrics;
+
+public class TopKCategoricalAccuracy : MeanMetricWrapper
+{
+ public TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
+ : base((yt, yp) => metrics_utils.sparse_top_k_categorical_matches(
+ tf.math.argmax(yt, axis: -1), yp, k),
+ name: name,
+ dtype: dtype)
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Utils/losses_utils.cs b/src/TensorFlowNET.Keras/Utils/losses_utils.cs
index 08330595..6de98861 100644
--- a/src/TensorFlowNET.Keras/Utils/losses_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/losses_utils.cs
@@ -15,6 +15,7 @@
******************************************************************************/
using System;
+using System.Xml.Linq;
using Tensorflow.Keras.Losses;
using static Tensorflow.Binding;
@@ -37,15 +38,57 @@ namespace Tensorflow.Keras.Utils
});
}
- public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor sample_weight)
+ public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null)
{
+ var y_pred_shape = y_pred.shape;
+ var y_pred_rank = y_pred_shape.ndim;
+ if (y_true != null)
+ {
+ var y_true_shape = y_true.shape;
+ var y_true_rank = y_true_shape.ndim;
+ if (y_true_rank > -1 && y_pred_rank > -1)
+ {
+ if (y_pred_rank - y_true_rank != 1 || y_pred_shape[-1] == 1)
+ {
+ (y_true, y_pred) = remove_squeezable_dimensions(y_true, y_pred);
+ }
+ }
+ }
+
+ if (sample_weight == null)
+ {
+ return (y_pred, y_true);
+ }
+
var weights_shape = sample_weight.shape;
var weights_rank = weights_shape.ndim;
if (weights_rank == 0)
return (y_pred, sample_weight);
+
+ if (y_pred_rank > -1 && weights_rank > -1)
+ {
+ if (weights_rank - y_pred_rank == 1)
+ {
+ sample_weight = tf.squeeze(sample_weight, -1);
+ }
+ else if (y_pred_rank - weights_rank == 1)
+ {
+ sample_weight = tf.expand_dims(sample_weight, -1);
+ }
+ else
+ {
+ return (y_pred, sample_weight);
+ }
+ }
+
throw new NotImplementedException("");
}
+ public static (Tensor, Tensor) remove_squeezable_dimensions(Tensor labels, Tensor predictions, int expected_rank_diff = 0, string name = null)
+ {
+ return (labels, predictions);
+ }
+
public static Tensor reduce_weighted_loss(Tensor weighted_losses, string reduction)
{
if (reduction == ReductionV2.NONE)
diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
index bb0107d4..a35763d0 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
@@ -14,6 +14,26 @@ namespace TensorFlowNET.Keras.UnitTest;
[TestClass]
public class MetricsTest : EagerModeTestBase
{
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
+ ///
+ [TestMethod]
+ public void TopKCategoricalAccuracy()
+ {
+ var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } });
+ var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });
+ var m = tf.keras.metrics.TopKCategoricalAccuracy(k: 1);
+ m.update_state(y_true, y_pred);
+ var r = m.result().numpy();
+ Assert.AreEqual(r, 0.5f);
+
+ m.reset_states();
+ var weights = np.array(new[] { 0.7f, 0.3f });
+ m.update_state(y_true, y_pred, sample_weight: weights);
+ r = m.result().numpy();
+ Assert.AreEqual(r, 0.3f);
+ }
+
///
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy
///