From 98919983b14124a22a91f4430c71857e384d7127 Mon Sep 17 00:00:00 2001
From: Haiping Chen
Date: Tue, 21 Feb 2023 21:55:59 -0600
Subject: [PATCH] Add metrics of Precision.

---
 src/TensorFlowNET.Core/APIs/tf.math.cs        | 11 ++++
 .../Keras/Metrics/IMetricsApi.cs              | 13 +++-
 src/TensorFlowNET.Core/Operations/nn_ops.cs   |  7 ++
 src/TensorFlowNET.Keras/Metrics/MetricsApi.cs |  5 +-
 src/TensorFlowNET.Keras/Metrics/Precision.cs  | 76 +++++++++++++++++++
 .../Metrics/metrics_utils.cs                  | 25 +++++-
 .../Metrics/MetricsTest.cs                    | 34 +++++++++
 7 files changed, 168 insertions(+), 3 deletions(-)
 create mode 100644 src/TensorFlowNET.Keras/Metrics/Precision.cs

diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index 7d3f6eff..0191f8d6 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -39,6 +39,17 @@ namespace Tensorflow
             public Tensor sum(Tensor x, Axis? axis = null, string name = null)
                 => math_ops.reduce_sum(x, axis: axis, name: name);
 
+            /// <summary>
+            /// Finds values and indices of the `k` largest entries for the last dimension.
+            /// </summary>
+            /// <param name="input">Tensor whose last dimension is searched.</param>
+            /// <param name="k">Number of top entries to look for along the last dimension.</param>
+            /// <param name="sorted">Forwarded to the TopKV2 op as its `sorted` attribute.</param>
+            /// <param name="name">Optional name for the operation.</param>
+            /// <returns>The values and the indices of the `k` largest entries.</returns>
+            public Tensors top_k(Tensor input, int k, bool sorted = true, string name = null)
+                => nn_ops.top_kv2(input, k, sorted: sorted, name: name);
+
             public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK")
                 => nn_ops.in_top_k(predictions, targets, k, name);
 
diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
index 95cc1e60..e27c198d 100644
--- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
@@ -36,6 +36,17 @@ public interface IMetricsApi
     /// <returns></returns>
     IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT);
 
+    /// <summary>
+    /// Computes the precision of the predictions with respect to the labels.
+    /// </summary>
+    /// <param name="thresholds">Threshold value forwarded to the confusion-matrix update.</param>
+    /// <param name="top_k">When greater than 0, only the `top_k` largest predictions are kept; 0 disables the filter.</param>
+    /// <param name="class_id">When greater than 0, only that column of the labels and predictions is evaluated.</param>
+    /// <param name="name">Name of the metric.</param>
+    /// <param name="dtype">Data type of the metric.</param>
+    /// <returns>The precision metric.</returns>
+    IMetricFunc Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "precision", TF_DataType dtype = TF_DataType.TF_FLOAT);
+
     /// <summary>
     /// Computes the recall of the predictions with respect to the labels.
     /// </summary>
@@ -45,5 +56,5 @@ public interface IMetricsApi
     /// <param name="name"></param>
     /// <param name="dtype"></param>
     /// <returns></returns>
-    IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
+    IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
 }
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs
index 307b1f8a..5877d234 100644
--- a/src/TensorFlowNET.Core/Operations/nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs
@@ -109,6 +109,13 @@ namespace Tensorflow
             return noise_shape;
         }
 
+        /// <summary>
+        /// Executes the "TopKV2" op on `input` and `k`, passing `sorted` through as an op attribute.
+        /// </summary>
+        public static Tensors top_kv2(Tensor input, int k, bool sorted = true, string name = null)
+            => tf.Context.ExecuteOp("TopKV2", name, new ExecuteOpArgs(input, k)
+                .SetAttributes(new { sorted }));
+
         public static Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = null)
         {
             return tf_with(ops.name_scope(name, "in_top_k"), delegate
diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
index 23742832..74db1666 100644
--- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
@@ -62,7 +62,10 @@
         public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
             => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);
 
-        public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT)
+        public IMetricFunc Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "precision", TF_DataType dtype = TF_DataType.TF_FLOAT)
+            => new Precision(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype);
+
+        public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT)
             => new Recall(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype);
     }
 }
diff --git a/src/TensorFlowNET.Keras/Metrics/Precision.cs b/src/TensorFlowNET.Keras/Metrics/Precision.cs
new file mode 100644
index 00000000..a01773e0
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Metrics/Precision.cs
@@ -0,0 +1,76 @@
+namespace Tensorflow.Keras.Metrics;
+
+/// <summary>
+/// Computes the precision of the predictions with respect to the labels,
+/// i.e. true_positives / (true_positives + false_positives).
+/// </summary>
+public class Precision : Metric
+{
+    Tensor _thresholds;
+    int _top_k;
+    int _class_id;
+    IVariableV1 true_positives;
+    IVariableV1 false_positives;
+    bool _thresholds_distributed_evenly;
+
+    /// <summary>
+    /// Creates the metric together with its zero-initialized `true_positives` and `false_positives` weights.
+    /// </summary>
+    /// <param name="thresholds">Threshold value; stored as a one-element constant tensor.</param>
+    /// <param name="top_k">When greater than 0, only the `top_k` largest predictions are kept; 0 disables the filter.</param>
+    /// <param name="class_id">When greater than 0, only that column is evaluated; values of 0 or less disable the slicing, so class 0 cannot be selected.</param>
+    /// <param name="name">Name of the metric, forwarded to the base Metric.</param>
+    /// <param name="dtype">Data type of the metric, forwarded to the base Metric.</param>
+    public Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "precision", TF_DataType dtype = TF_DataType.TF_FLOAT)
+        : base(name: name, dtype: dtype)
+    {
+        _thresholds = constant_op.constant(new float[] { thresholds });
+        _top_k = top_k;
+        _class_id = class_id;
+        true_positives = add_weight("true_positives", shape: 1, initializer: tf.initializers.zeros_initializer());
+        false_positives = add_weight("false_positives", shape: 1, initializer: tf.initializers.zeros_initializer());
+    }
+
+    /// <summary>
+    /// Delegates to metrics_utils.update_confusion_matrix_variables to update the "tp" and "fp" variables.
+    /// </summary>
+    public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
+    {
+        return metrics_utils.update_confusion_matrix_variables(
+            new Dictionary<string, IVariableV1>
+            {
+                { "tp", true_positives },
+                { "fp", false_positives },
+            },
+            y_true,
+            y_pred,
+            thresholds: _thresholds,
+            thresholds_distributed_evenly: _thresholds_distributed_evenly,
+            top_k: _top_k,
+            class_id: _class_id,
+            sample_weight: sample_weight);
+    }
+
+    /// <summary>
+    /// Returns true_positives / (true_positives + false_positives); the first element only when there is a single threshold.
+    /// </summary>
+    public override Tensor result()
+    {
+        var result = tf.divide(true_positives.AsTensor(), tf.add(true_positives, false_positives));
+        return _thresholds.size == 1 ? result[0] : result;
+    }
+
+    /// <summary>
+    /// Sets both accumulators back to zeros, one entry per threshold.
+    /// </summary>
+    public override void reset_states()
+    {
+        var num_thresholds = (int)_thresholds.size;
+        keras.backend.batch_set_value(
+            new List<(IVariableV1, NDArray)>
+            {
+                (true_positives, np.zeros(num_thresholds)),
+                (false_positives, np.zeros(num_thresholds))
+            });
+    }
+}
diff --git a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
index d09b3c72..0251462e 100644
--- a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
+++ b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
@@ -78,6 +78,17 @@ public class metrics_utils
                 sample_weight: sample_weight);
         }
 
+        if (top_k > 0)
+        {
+            y_pred = _filter_top_k(y_pred, top_k);
+        }
+
+        if (class_id > 0)
+        {
+            y_true = y_true[Slice.All, class_id];
+            y_pred = y_pred[Slice.All, class_id];
+        }
+
         if (thresholds_distributed_evenly)
         {
             throw new NotImplementedException();
@@ -204,5 +215,17 @@ public class metrics_utils
         tf.group(update_ops.ToArray());
 
         return null;
-    }
+    }
+
+    /// <summary>
+    /// Keeps the `k` largest entries along the last axis of `x` and replaces every other entry with NEG_INF.
+    /// </summary>
+    private static Tensor _filter_top_k(Tensor x, int k)
+    {
+        var NEG_INF = -1e10;
+        var (_, top_k_idx) = tf.math.top_k(x, k, sorted: false);
+        var top_k_mask = tf.reduce_sum(
+            tf.one_hot(top_k_idx, (int)x.shape[-1], axis: -1), axis: -2);
+        return x * top_k_mask + NEG_INF * (1 - top_k_mask);
+    }
 }
diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
index 84382bb4..f3ba2e93 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
@@ -46,6 +46,40 @@ public class MetricsTest : EagerModeTestBase
         Assert.AreEqual(m.numpy(), new[] { 1f, 1f });
     }
 
+    /// <summary>
+    /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Precision
+    /// </summary>
+    [TestMethod]
+    public void Precision()
+    {
+        var y_true = np.array(new[] { 0, 1, 1, 1 });
+        var y_pred = np.array(new[] { 1, 0, 1, 1 });
+        var m = tf.keras.metrics.Precision();
+        m.update_state(y_true, y_pred);
+        var r = m.result().numpy();
+        Assert.AreEqual(r, 0.6666667f);
+
+        m.reset_states();
+        var weights = np.array(new[] { 0f, 0f, 1f, 0f });
+        m.update_state(y_true, y_pred, sample_weight: weights);
+        r = m.result().numpy();
+        Assert.AreEqual(r, 1f);
+
+        // With top_k=2, it will calculate precision over y_true[:2]
+        // and y_pred[:2]
+        m = tf.keras.metrics.Precision(top_k: 2);
+        m.update_state(np.array(new[] { 0, 0, 1, 1 }), np.array(new[] { 1, 1, 1, 1 }));
+        r = m.result().numpy();
+        Assert.AreEqual(r, 0f);
+
+        // With top_k=4, it will calculate precision over y_true[:4]
+        // and y_pred[:4]
+        m = tf.keras.metrics.Precision(top_k: 4);
+        m.update_state(np.array(new[] { 0, 0, 1, 1 }), np.array(new[] { 1, 1, 1, 1 }));
+        r = m.result().numpy();
+        Assert.AreEqual(r, 0.5f);
+    }
+
     /// <summary>
     /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall
     /// </summary>