From ca9f574fce755dd92f365d732b1ff1a20b568ecf Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Sun, 19 Feb 2023 13:32:21 -0600 Subject: [PATCH] Add metric of top_k_categorical_accuracy. --- src/TensorFlowNET.Core/APIs/tf.math.cs | 3 ++ src/TensorFlowNET.Core/Keras/IKerasApi.cs | 2 + .../Keras/Metrics/IMetricsApi.cs | 29 ++++++++++++++ .../Operations/NnOps/gen_nn_ops.cs | 12 +----- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 5 +++ src/TensorFlowNET.Keras/KerasInterface.cs | 2 +- src/TensorFlowNET.Keras/Metrics/MetricsApi.cs | 9 ++++- .../Metrics/metrics_utils.cs | 39 +++++++++++++++++++ .../Metrics/MetricsTest.cs | 28 +++++++++++++ 9 files changed, 117 insertions(+), 12 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs create mode 100644 src/TensorFlowNET.Keras/Metrics/metrics_utils.cs create mode 100644 test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index ce6dc4d6..7d3f6eff 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -39,6 +39,9 @@ namespace Tensorflow public Tensor sum(Tensor x, Axis? axis = null, string name = null) => math_ops.reduce_sum(x, axis: axis, name: name); + public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK") + => nn_ops.in_top_k(predictions, targets, k, name); + /// /// /// diff --git a/src/TensorFlowNET.Core/Keras/IKerasApi.cs b/src/TensorFlowNET.Core/Keras/IKerasApi.cs index 49ec9a5f..cffd3f79 100644 --- a/src/TensorFlowNET.Core/Keras/IKerasApi.cs +++ b/src/TensorFlowNET.Core/Keras/IKerasApi.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.Layers; using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; namespace Tensorflow.Keras { @@ -10,6 +11,7 @@ namespace Tensorflow.Keras { public ILayersApi layers { get; } public ILossesApi losses { get; } + public IMetricsApi metrics { get; } public IInitializersApi initializers { get; } } } diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs new file mode 100644 index 00000000..2fe6d809 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs @@ -0,0 +1,29 @@ +namespace Tensorflow.Keras.Metrics; + +public interface IMetricsApi +{ + Tensor binary_accuracy(Tensor y_true, Tensor y_pred); + + Tensor categorical_accuracy(Tensor y_true, Tensor y_pred); + + Tensor mean_absolute_error(Tensor y_true, Tensor y_pred); + + Tensor mean_absolute_percentage_error(Tensor y_true, Tensor y_pred); + + /// + /// Calculates how often predictions matches integer labels. + /// + /// Integer ground truth values. + /// The prediction values. + /// Sparse categorical accuracy values. + Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred); + + /// + /// Computes how often targets are in the top `K` predictions. + /// + /// + /// + /// + /// + Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5); +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index 0567858f..408d06eb 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -240,16 +240,8 @@ namespace Tensorflow.Operations /// /// A `Tensor` of type `bool`. public static Tensor in_top_kv2(Tensor predictions, Tensor targets, int k, string name = null) - { - var _op = tf.OpDefLib._apply_op_helper("InTopKV2", name: name, args: new - { - predictions, - targets, - k - }); - - return _op.output; - } + => tf.Context.ExecuteOp("InTopKV2", name, + new ExecuteOpArgs(predictions, targets, k)); public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) => tf.Context.ExecuteOp("LeakyRelu", name, diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 5f09f202..7af89f13 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -121,6 +121,11 @@ namespace Tensorflow if (dtype == TF_DataType.TF_INT32) values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray(); } + else if (values is double[] double_values) + { + if (dtype == TF_DataType.TF_FLOAT) + values = double_values.Select(x => (float)Convert.ChangeType(x, new_system_dtype)).ToArray(); + } else values = Convert.ChangeType(values, new_system_dtype); diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs index 4e0c612b..e0d148ce 100644 --- a/src/TensorFlowNET.Keras/KerasInterface.cs +++ b/src/TensorFlowNET.Keras/KerasInterface.cs @@ -27,7 +27,7 @@ namespace Tensorflow.Keras ThreadLocal _backend = new ThreadLocal(() => new BackendImpl()); public BackendImpl backend => _backend.Value; public OptimizerApi optimizers { get; } = new OptimizerApi(); - public MetricsApi metrics { get; } = new MetricsApi(); + public IMetricsApi metrics { get; } = new MetricsApi(); public ModelsApi models { get; } = new ModelsApi(); public KerasUtils utils { get; } = new KerasUtils(); diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs index 3d614e02..6b0e2d8a 100644 --- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs +++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs @@ -2,7 +2,7 @@ namespace Tensorflow.Keras.Metrics { - public class MetricsApi + public class MetricsApi : IMetricsApi { public Tensor binary_accuracy(Tensor y_true, Tensor y_pred) { @@ -53,5 +53,12 @@ namespace Tensorflow.Keras.Metrics var diff = (y_true - y_pred) / math_ops.maximum(math_ops.abs(y_true), keras.backend.epsilon()); return 100f * keras.backend.mean(math_ops.abs(diff), axis: -1); } + + public Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5) + { + return metrics_utils.sparse_top_k_categorical_matches( + tf.math.argmax(y_true, axis: -1), y_pred, k + ); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs new file mode 100644 index 00000000..de6a8402 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs @@ -0,0 +1,39 @@ +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.Metrics; + +public class metrics_utils +{ + public static Tensor sparse_top_k_categorical_matches(Tensor y_true, Tensor y_pred, int k = 5) + { + var reshape_matches = false; + var y_true_rank = y_true.shape.ndim; + var y_pred_rank = y_pred.shape.ndim; + var y_true_org_shape = tf.shape(y_true); + + if (y_pred_rank > 2) + { + y_pred = tf.reshape(y_pred, (-1, y_pred.shape[-1])); + } + + if (y_true_rank > 1) + { + reshape_matches = true; + y_true = tf.reshape(y_true, new Shape(-1)); + } + + var matches = tf.cast( + tf.math.in_top_k( + predictions: y_pred, targets: tf.cast(y_true, np.int32), k: k + ), + dtype: keras.backend.floatx() + ); + + if (reshape_matches) + { + return tf.reshape(matches, shape: y_true_org_shape); + } + + return matches; + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs new file mode 100644 index 00000000..bb0107d4 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs @@ -0,0 +1,28 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tensorflow; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace TensorFlowNET.Keras.UnitTest; + +[TestClass] +public class MetricsTest : EagerModeTestBase +{ + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy + /// + [TestMethod] + public void top_k_categorical_accuracy() + { + var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } }); + var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } }); + var m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k: 3); + Assert.AreEqual(m.numpy(), new[] { 1f, 1f }); + } +}