From a5289b9bb3ab98f54186d0627b2a8dde5c1e215e Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Sun, 19 Feb 2023 15:43:47 -0600 Subject: [PATCH] Abstract IMetricFunc. --- .../Keras/Metrics/IMetricFunc.cs | 17 +++++++ .../Keras/Metrics/IMetricsApi.cs | 9 ++++ .../Metrics/MeanMetricWrapper.cs | 3 ++ src/TensorFlowNET.Keras/Metrics/Metric.cs | 2 +- src/TensorFlowNET.Keras/Metrics/MetricsApi.cs | 7 +-- .../Metrics/TopKCategoricalAccuracy.cs | 12 +++++ src/TensorFlowNET.Keras/Utils/losses_utils.cs | 45 ++++++++++++++++++- .../Metrics/MetricsTest.cs | 20 +++++++++ 8 files changed, 110 insertions(+), 5 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs create mode 100644 src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs new file mode 100644 index 00000000..1867d637 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs @@ -0,0 +1,17 @@ +namespace Tensorflow.Keras.Metrics; + +public interface IMetricFunc +{ + /// + /// Accumulates metric statistics. + /// + /// + /// + /// + /// + Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null); + + Tensor result(); + + void reset_states(); +} diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs index 2fe6d809..511b0ef1 100644 --- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs +++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs @@ -26,4 +26,13 @@ public interface IMetricsApi /// /// Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5); + + /// + /// Computes how often targets are in the top K predictions. + /// + /// + /// + /// + /// + IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT); } diff --git a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs index c422bfa6..2e985b88 100644 --- a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs +++ b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs @@ -1,4 +1,5 @@ using System; +using Tensorflow.Keras.Utils; namespace Tensorflow.Keras.Metrics { @@ -17,6 +18,8 @@ namespace Tensorflow.Keras.Metrics y_true = math_ops.cast(y_true, _dtype); y_pred = math_ops.cast(y_pred, _dtype); + (y_pred, y_true) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true); + var matches = _fn(y_true, y_pred); return update_state(matches, sample_weight: sample_weight); } diff --git a/src/TensorFlowNET.Keras/Metrics/Metric.cs b/src/TensorFlowNET.Keras/Metrics/Metric.cs index 21457f15..1dfc39c4 100644 --- a/src/TensorFlowNET.Keras/Metrics/Metric.cs +++ b/src/TensorFlowNET.Keras/Metrics/Metric.cs @@ -9,7 +9,7 @@ namespace Tensorflow.Keras.Metrics /// /// Encapsulates metric logic and state. /// - public class Metric : Layer + public class Metric : Layer, IMetricFunc { protected IVariableV1 total; protected IVariableV1 count; diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs index 6b0e2d8a..dfccfdbb 100644 --- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs +++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs @@ -1,6 +1,4 @@ -using static Tensorflow.KerasApi; - -namespace Tensorflow.Keras.Metrics +namespace Tensorflow.Keras.Metrics { public class MetricsApi : IMetricsApi { @@ -60,5 +58,8 @@ namespace Tensorflow.Keras.Metrics tf.math.argmax(y_true, axis: -1), y_pred, k ); } + + public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); } } diff --git a/src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs new file mode 100644 index 00000000..63e94102 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs @@ -0,0 +1,12 @@ +namespace Tensorflow.Keras.Metrics; + +public class TopKCategoricalAccuracy : MeanMetricWrapper +{ + public TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + : base((yt, yp) => metrics_utils.sparse_top_k_categorical_matches( + tf.math.argmax(yt, axis: -1), yp, k), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Utils/losses_utils.cs b/src/TensorFlowNET.Keras/Utils/losses_utils.cs index 08330595..6de98861 100644 --- a/src/TensorFlowNET.Keras/Utils/losses_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/losses_utils.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using System.Xml.Linq; using Tensorflow.Keras.Losses; using static Tensorflow.Binding; @@ -37,15 +38,57 @@ namespace Tensorflow.Keras.Utils }); } - public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor sample_weight) + public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null) { + var y_pred_shape = y_pred.shape; + var y_pred_rank = y_pred_shape.ndim; + if (y_true != null) + { + var y_true_shape = y_true.shape; + var y_true_rank = y_true_shape.ndim; + if (y_true_rank > -1 && y_pred_rank > -1) + { + if (y_pred_rank - y_true_rank != 1 || y_pred_shape[-1] == 1) + { + (y_true, y_pred) = remove_squeezable_dimensions(y_true, y_pred); + } + } + } + + if (sample_weight == null) + { + return (y_pred, y_true); + } + var weights_shape = sample_weight.shape; var weights_rank = weights_shape.ndim; if (weights_rank == 0) return (y_pred, sample_weight); + + if (y_pred_rank > -1 && weights_rank > -1) + { + if (weights_rank - y_pred_rank == 1) + { + sample_weight = tf.squeeze(sample_weight, -1); + } + else if (y_pred_rank - weights_rank == 1) + { + sample_weight = tf.expand_dims(sample_weight, -1); + } + else + { + return (y_pred, sample_weight); + } + } + throw new NotImplementedException(""); } + public static (Tensor, Tensor) remove_squeezable_dimensions(Tensor labels, Tensor predictions, int expected_rank_diff = 0, string name = null) + { + return (labels, predictions); + } + public static Tensor reduce_weighted_loss(Tensor weighted_losses, string reduction) { if (reduction == ReductionV2.NONE) diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs index bb0107d4..a35763d0 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs @@ -14,6 +14,26 @@ namespace TensorFlowNET.Keras.UnitTest; [TestClass] public class MetricsTest : EagerModeTestBase { + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy + /// + [TestMethod] + public void TopKCategoricalAccuracy() + { + var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } }); + var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } }); + var m = tf.keras.metrics.TopKCategoricalAccuracy(k: 1); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.5f); + + m.reset_states(); + var weights = np.array(new[] { 0.7f, 0.3f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 0.3f); + } + /// /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy ///