From 067c1ff92aaa35a65dc3e659111404cc8a8c052b Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Fri, 24 Feb 2023 17:35:48 -0600 Subject: [PATCH] Add metrics of F1Score and FBetaScore. --- src/TensorFlowNET.Core/APIs/tf.math.cs | 9 ++ .../Keras/Metrics/IMetricsApi.cs | 21 +++ .../Tensorflow.Binding.csproj | 6 +- src/TensorFlowNET.Keras/Metrics/F1Score.cs | 13 ++ src/TensorFlowNET.Keras/Metrics/FBetaScore.cs | 131 ++++++++++++++++++ src/TensorFlowNET.Keras/Metrics/MetricsApi.cs | 6 + .../Tensorflow.Keras.csproj | 6 +- .../Metrics/MetricsTest.cs | 28 ++++ 8 files changed, 214 insertions(+), 6 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Metrics/F1Score.cs create mode 100644 src/TensorFlowNET.Keras/Metrics/FBetaScore.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 0191f8d6..dabdf126 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -36,6 +36,15 @@ namespace Tensorflow public Tensor erf(Tensor x, string name = null) => math_ops.erf(x, name); + public Tensor multiply(Tensor x, Tensor y, string name = null) + => math_ops.multiply(x, y, name: name); + + public Tensor divide_no_nan(Tensor a, Tensor b, string name = null) + => math_ops.div_no_nan(a, b); + + public Tensor square(Tensor x, string name = null) + => math_ops.square(x, name: name); + public Tensor sum(Tensor x, Axis? axis = null, string name = null) => math_ops.reduce_sum(x, axis: axis, name: name); diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs index e4575620..271ca6e1 100644 --- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs +++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs @@ -71,6 +71,27 @@ public interface IMetricsApi TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null); + /// + /// Computes F-1 Score. + /// + /// + IMetricFunc F1Score(int num_classes, + string? average = null, + float threshold = -1f, + string name = "fbeta_score", + TF_DataType dtype = TF_DataType.TF_FLOAT); + + /// + /// Computes F-Beta score. + /// + /// + IMetricFunc FBetaScore(int num_classes, + string? average = null, + float beta = 0.1f, + float threshold = -1f, + string name = "fbeta_score", + TF_DataType dtype = TF_DataType.TF_FLOAT); + /// /// Computes how often targets are in the top K predictions. /// diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index c2b53e76..ecb63a7b 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -5,7 +5,7 @@ Tensorflow.Binding Tensorflow 2.10.0 - 0.100.3 + 0.100.4 10.0 enable Haiping Chen, Meinrad Recheis, Eli Belash @@ -20,7 +20,7 @@ Google's TensorFlow full binding in .NET Standard. Building, training and infering deep learning models. https://tensorflownet.readthedocs.io - 0.100.3.0 + 0.100.4.0 tf.net 0.100.x and above are based on tensorflow native 2.10.0 @@ -38,7 +38,7 @@ https://tensorflownet.readthedocs.io tf.net 0.7x.x aligns with TensorFlow v2.7.x native library. tf.net 0.10x.x aligns with TensorFlow v2.10.x native library. - 0.100.3.0 + 0.100.4.0 LICENSE true true diff --git a/src/TensorFlowNET.Keras/Metrics/F1Score.cs b/src/TensorFlowNET.Keras/Metrics/F1Score.cs new file mode 100644 index 00000000..c3276f3e --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/F1Score.cs @@ -0,0 +1,13 @@ +namespace Tensorflow.Keras.Metrics; + +public class F1Score : FBetaScore +{ + public F1Score(int num_classes, + string? average = null, + float? threshold = -1f, + string name = "f1_score", + TF_DataType dtype = TF_DataType.TF_FLOAT) + : base(num_classes, average: average, threshold: threshold, beta: 1f, name: name, dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/FBetaScore.cs b/src/TensorFlowNET.Keras/Metrics/FBetaScore.cs new file mode 100644 index 00000000..ab4d00a9 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/FBetaScore.cs @@ -0,0 +1,131 @@ +namespace Tensorflow.Keras.Metrics; + +public class FBetaScore : Metric +{ + int _num_classes; + string? _average; + Tensor _beta; + Tensor _threshold; + Axis _axis; + int[] _init_shape; + + IVariableV1 true_positives; + IVariableV1 false_positives; + IVariableV1 false_negatives; + IVariableV1 weights_intermediate; + + public FBetaScore(int num_classes, + string? average = null, + float beta = 0.1f, + float? threshold = -1f, + string name = "fbeta_score", + TF_DataType dtype = TF_DataType.TF_FLOAT) + : base(name: name, dtype: dtype) + { + _num_classes = num_classes; + _average = average; + _beta = constant_op.constant(beta); + _dtype = dtype; + + if (threshold.HasValue) + { + _threshold = constant_op.constant(threshold); + } + + _init_shape = new int[0]; + + if (average != "micro") + { + _axis = 0; + _init_shape = new int[] { num_classes }; + } + + true_positives = add_weight("true_positives", shape: _init_shape, initializer: tf.initializers.zeros_initializer()); + false_positives = add_weight("false_positives", shape: _init_shape, initializer: tf.initializers.zeros_initializer()); + false_negatives = add_weight("false_negatives", shape: _init_shape, initializer: tf.initializers.zeros_initializer()); + weights_intermediate = add_weight("weights_intermediate", shape: _init_shape, initializer: tf.initializers.zeros_initializer()); + } + + public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) + { + if (_threshold == null) + { + _threshold = tf.reduce_max(y_pred, axis: -1, keepdims: true); + // make sure [0, 0, 0] doesn't become [1, 1, 1] + // Use abs(x) > eps, instead of x != 0 to check for zero + y_pred = tf.logical_and(y_pred >= _threshold, tf.abs(y_pred) > 1e-12); + } + else + { + y_pred = y_pred > _threshold; + } + + y_true = tf.cast(y_true, _dtype); + y_pred = tf.cast(y_pred, _dtype); + + true_positives.assign_add(_weighted_sum(y_pred * y_true, sample_weight)); + false_positives.assign_add( + _weighted_sum(y_pred * (1 - y_true), sample_weight) + ); + false_negatives.assign_add( + _weighted_sum((1 - y_pred) * y_true, sample_weight) + ); + weights_intermediate.assign_add(_weighted_sum(y_true, sample_weight)); + + return weights_intermediate.AsTensor(); + } + + Tensor _weighted_sum(Tensor val, Tensor? sample_weight = null) + { + if (sample_weight != null) + { + val = tf.math.multiply(val, tf.expand_dims(sample_weight, 1)); + } + + return tf.reduce_sum(val, axis: _axis); + } + + public override Tensor result() + { + var precision = tf.math.divide_no_nan( + true_positives.AsTensor(), true_positives.AsTensor() + false_positives.AsTensor() + ); + var recall = tf.math.divide_no_nan( + true_positives.AsTensor(), true_positives.AsTensor() + false_negatives.AsTensor() + ); + + var mul_value = precision * recall; + var add_value = (tf.math.square(_beta) * precision) + recall; + var mean = tf.math.divide_no_nan(mul_value, add_value); + var f1_score = mean * (1 + tf.math.square(_beta)); + + Tensor weights; + if (_average == "weighted") + { + weights = tf.math.divide_no_nan( + weights_intermediate.AsTensor(), tf.reduce_sum(weights_intermediate.AsTensor()) + ); + f1_score = tf.reduce_sum(f1_score * weights); + } + // micro, macro + else if (_average != null) + { + f1_score = tf.reduce_mean(f1_score); + } + + return f1_score; + } + + public override void reset_states() + { + var reset_value = np.zeros(_init_shape, dtype: _dtype); + keras.backend.batch_set_value( + new List<(IVariableV1, NDArray)> + { + (true_positives, reset_value), + (false_positives, reset_value), + (false_negatives, reset_value), + (weights_intermediate, reset_value) + }); + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs index e207d27d..5230fe59 100644 --- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs +++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs @@ -86,6 +86,12 @@ public IMetricFunc CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null) => new CosineSimilarity(name: name, dtype: dtype, axis: axis ?? -1); + public IMetricFunc F1Score(int num_classes, string? average = null, float threshold = -1, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new F1Score(num_classes, average: average, threshold: threshold, name: name, dtype: dtype); + + public IMetricFunc FBetaScore(int num_classes, string? average = null, float beta = 0.1F, float threshold = -1, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new FBetaScore(num_classes, average: average,beta: beta, threshold: threshold, name: name, dtype: dtype); + public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index 264b9501..104e6433 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -7,7 +7,7 @@ enable Tensorflow.Keras AnyCPU;x64 - 0.10.3 + 0.10.4 Haiping Chen Keras for .NET Apache 2.0, Haiping Chen 2023 @@ -37,8 +37,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac Git true Open.snk - 0.10.3.0 - 0.10.3.0 + 0.10.4.0 + 0.10.4.0 LICENSE Debug;Release;GPU diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs index 90be51bd..2b38449b 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs @@ -114,6 +114,34 @@ public class MetricsTest : EagerModeTestBase Assert.AreEqual(r, 0.6999999f); } + /// + /// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score + /// + [TestMethod] + public void F1Score() + { + var y_true = np.array(new[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } }); + var y_pred = np.array(new[,] { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } }); + var m = tf.keras.metrics.F1Score(num_classes: 3, threshold: 0.5f); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, new[] { 0.5f, 0.8f, 0.6666667f }); + } + + /// + /// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/FBetaScore + /// + [TestMethod] + public void FBetaScore() + { + var y_true = np.array(new[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } }); + var y_pred = np.array(new[,] { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } }); + var m = tf.keras.metrics.FBetaScore(num_classes: 3, beta: 2.0f, threshold: 0.5f); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, new[] { 0.3846154f, 0.90909094f, 0.8333334f }); + } + /// /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy ///