diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
index 511b0ef1..95cc1e60 100644
--- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
@@ -35,4 +35,15 @@ public interface IMetricsApi
     /// </summary>
     /// <returns></returns>
     IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT);
+
+    /// <summary>
+    /// Computes the recall of the predictions with respect to the labels.
+    /// </summary>
+    /// <param name="thresholds"></param>
+    /// <param name="top_k"></param>
+    /// <param name="class_id"></param>
+    /// <param name="name"></param>
+    /// <param name="dtype"></param>
+    /// <returns></returns>
+    IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
 }
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index 263509f6..0e888a0a 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -221,6 +221,9 @@ namespace Tensorflow
                 case Tensor t:
                     dtype = t.dtype.as_base_dtype();
                     break;
+                case int t:
+                    dtype = TF_DataType.TF_INT32;
+                    break;
             }
 
             if (dtype != TF_DataType.DtInvalid)
diff --git a/src/TensorFlowNET.Keras/GlobalUsing.cs b/src/TensorFlowNET.Keras/GlobalUsing.cs
index 72ff8b28..bc0798ed 100644
--- a/src/TensorFlowNET.Keras/GlobalUsing.cs
+++ b/src/TensorFlowNET.Keras/GlobalUsing.cs
@@ -1,5 +1,7 @@
 global using System;
 global using System.Collections.Generic;
 global using System.Text;
+global using System.Linq;
 global using static Tensorflow.Binding;
-global using static Tensorflow.KerasApi;
\ No newline at end of file
+global using static Tensorflow.KerasApi;
+global using Tensorflow.NumPy;
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
index 2e985b88..7173aae1 100644
--- a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Metrics
             y_true = math_ops.cast(y_true, _dtype);
             y_pred = math_ops.cast(y_pred, _dtype);
 
-            (y_pred, y_true) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true);
+            (y_pred, y_true, _) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true);
 
             var matches = _fn(y_true, y_pred);
             return update_state(matches, sample_weight: sample_weight);
diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
index dfccfdbb..23742832 100644
--- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
@@ -61,5 +61,8 @@
 
         public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
             => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);
+
+        public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT)
+            => new Recall(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype);
     }
 }
diff --git a/src/TensorFlowNET.Keras/Metrics/Recall.cs b/src/TensorFlowNET.Keras/Metrics/Recall.cs
new file mode 100644
index 00000000..9b58bf5f
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Metrics/Recall.cs
@@ -0,0 +1,53 @@
+namespace Tensorflow.Keras.Metrics;
+
+public class Recall : Metric
+{
+    Tensor _thresholds;
+    int _top_k;
+    int _class_id;
+    IVariableV1 true_positives;
+    IVariableV1 false_negatives;
+    bool _thresholds_distributed_evenly;
+
+    public Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT)
+        : base(name: name, dtype: dtype)
+    {
+        _thresholds = constant_op.constant(new float[] { thresholds });
+        true_positives = add_weight("true_positives", shape: 1, initializer: tf.initializers.zeros_initializer());
+        false_negatives = add_weight("false_negatives", shape: 1, initializer: tf.initializers.zeros_initializer());
+    }
+
+    public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
+    {
+        return metrics_utils.update_confusion_matrix_variables(
+            new Dictionary<string, IVariableV1>
+            {
+                { "tp", true_positives },
+                { "fn", false_negatives },
+            },
+            y_true,
+            y_pred,
+            thresholds: _thresholds,
+            thresholds_distributed_evenly: _thresholds_distributed_evenly,
+            top_k: _top_k,
+            class_id: _class_id,
+            sample_weight: sample_weight);
+    }
+
+    public override Tensor result()
+    {
+        var result = tf.divide(true_positives.AsTensor(), tf.add(true_positives, false_negatives));
+        return _thresholds.size == 1 ? result[0] : result;
+    }
+
+    public override void reset_states()
+    {
+        var num_thresholds = (int)_thresholds.size;
+        keras.backend.batch_set_value(
+            new List<(IVariableV1, NDArray)>
+            {
+                (true_positives, np.zeros(num_thresholds)),
+                (false_negatives, np.zeros(num_thresholds))
+            });
+    }
+}
diff --git a/src/TensorFlowNET.Keras/Metrics/Reduce.cs b/src/TensorFlowNET.Keras/Metrics/Reduce.cs
index f7cdb8f5..8874719d 100644
--- a/src/TensorFlowNET.Keras/Metrics/Reduce.cs
+++ b/src/TensorFlowNET.Keras/Metrics/Reduce.cs
@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Metrics
         {
             if (sample_weight != null)
             {
-                (values, sample_weight) = losses_utils.squeeze_or_expand_dimensions(
+                (values, _, sample_weight) = losses_utils.squeeze_or_expand_dimensions(
                     values, sample_weight: sample_weight);
 
                 sample_weight = math_ops.cast(sample_weight, dtype: values.dtype);
diff --git a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
index de6a8402..d09b3c72 100644
--- a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
+++ b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
@@ -1,4 +1,5 @@
-using Tensorflow.NumPy;
+using Tensorflow.Keras.Utils;
+using Tensorflow.NumPy;
 
 namespace Tensorflow.Keras.Metrics;
 
@@ -36,4 +37,172 @@ public class metrics_utils
 
         return matches;
     }
+
+    public static Tensor update_confusion_matrix_variables(Dictionary<string, IVariableV1> variables_to_update,
+        Tensor y_true,
+        Tensor y_pred,
+        Tensor thresholds,
+        int top_k,
+        int class_id,
+        Tensor sample_weight = null,
+        bool multi_label = false,
+        Tensor label_weights = null,
+        bool thresholds_distributed_evenly = false)
+    {
+        var variable_dtype = variables_to_update.Values.First().dtype;
+        y_true = tf.cast(y_true, dtype: variable_dtype);
+        y_pred = tf.cast(y_pred, dtype: variable_dtype);
+        var num_thresholds = thresholds.shape.dims[0];
+
+        Tensor one_thresh = null;
+        if (multi_label)
+        {
+            one_thresh = tf.equal(tf.cast(constant_op.constant(1), dtype: tf.int32),
+                tf.rank(thresholds),
+                name: "one_set_of_thresholds_cond");
+        }
+        else
+        {
+            one_thresh = tf.cast(constant_op.constant(true), dtype: dtypes.@bool);
+        }
+
+        if (sample_weight == null)
+        {
+            (y_pred, y_true, _) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true);
+        }
+        else
+        {
+            sample_weight = tf.cast(sample_weight, dtype: variable_dtype);
+            (y_pred, y_true, sample_weight) = losses_utils.squeeze_or_expand_dimensions(y_pred,
+                y_true,
+                sample_weight: sample_weight);
+        }
+
+        if (thresholds_distributed_evenly)
+        {
+            throw new NotImplementedException();
+        }
+
+        var pred_shape = tf.shape(y_pred);
+        var num_predictions = pred_shape[0];
+
+        Tensor num_labels;
+        if (y_pred.shape.ndim == 1)
+        {
+            num_labels = constant_op.constant(1);
+        }
+        else
+        {
+            num_labels = tf.reduce_prod(pred_shape["1:"], axis: 0);
+        }
+        var thresh_label_tile = tf.where(one_thresh, num_labels, tf.ones(new int[0], dtype: tf.int32));
+
+        // Reshape predictions and labels, adding a dim for thresholding.
+        Tensor predictions_extra_dim, labels_extra_dim;
+        if (multi_label)
+        {
+            predictions_extra_dim = tf.expand_dims(y_pred, 0);
+            labels_extra_dim = tf.expand_dims(tf.cast(y_true, dtype: tf.@bool), 0);
+        }
+
+        else
+        {
+            // Flatten predictions and labels when not multilabel.
+            predictions_extra_dim = tf.reshape(y_pred, (1, -1));
+            labels_extra_dim = tf.reshape(tf.cast(y_true, dtype: tf.@bool), (1, -1));
+        }
+
+        // Tile the thresholds for every prediction.
+        object[] thresh_pretile_shape, thresh_tiles, data_tiles;
+
+        if (multi_label)
+        {
+            thresh_pretile_shape = new object[] { num_thresholds, 1, -1 };
+            thresh_tiles = new object[] { 1, num_predictions, thresh_label_tile };
+            data_tiles = new object[] { num_thresholds, 1, 1 };
+        }
+        else
+        {
+            thresh_pretile_shape = new object[] { num_thresholds, -1 };
+            thresh_tiles = new object[] { 1, num_predictions * num_labels };
+            data_tiles = new object[] { num_thresholds, 1 };
+        }
+        var thresh_tiled = tf.tile(tf.reshape(thresholds, thresh_pretile_shape), tf.stack(thresh_tiles));
+
+        // Tile the predictions for every threshold.
+        var preds_tiled = tf.tile(predictions_extra_dim, data_tiles);
+
+        // Compare predictions and threshold.
+        var pred_is_pos = tf.greater(preds_tiled, thresh_tiled);
+
+        // Tile labels by number of thresholds
+        var label_is_pos = tf.tile(labels_extra_dim, data_tiles);
+
+        Tensor weights_tiled = null;
+
+        if (sample_weight != null)
+        {
+            /*sample_weight = broadcast_weights(
+                tf.cast(sample_weight, dtype: variable_dtype), y_pred);*/
+            weights_tiled = tf.tile(
+                tf.reshape(sample_weight, thresh_tiles), data_tiles);
+        }
+
+        if (label_weights != null && !multi_label)
+        {
+            throw new NotImplementedException();
+        }
+
+        Func<Tensor, Tensor, Tensor, IVariableV1, Tensor> weighted_assign_add
+            = (label, pred, weights, var) =>
+            {
+                var label_and_pred = tf.cast(tf.logical_and(label, pred), dtype: var.dtype);
+                if (weights != null)
+                {
+                    label_and_pred *= tf.cast(weights, dtype: var.dtype);
+                }
+
+                return var.assign_add(tf.reduce_sum(label_and_pred, 1));
+            };
+
+
+        var loop_vars = new Dictionary<string, (Tensor, Tensor)>
+        {
+            { "tp", (label_is_pos, pred_is_pos) }
+        };
+        var update_tn = variables_to_update.ContainsKey("tn");
+        var update_fp = variables_to_update.ContainsKey("fp");
+        var update_fn = variables_to_update.ContainsKey("fn");
+
+        Tensor pred_is_neg = null;
+        if (update_fn || update_tn)
+        {
+            pred_is_neg = tf.logical_not(pred_is_pos);
+            loop_vars["fn"] = (label_is_pos, pred_is_neg);
+        }
+
+        if (update_fp || update_tn)
+        {
+            var label_is_neg = tf.logical_not(label_is_pos);
+            loop_vars["fp"] = (label_is_neg, pred_is_pos);
+            if (update_tn)
+            {
+                loop_vars["tn"] = (label_is_neg, pred_is_neg);
+            }
+        }
+
+        var update_ops = new List<Tensor>();
+        foreach (var matrix_cond in loop_vars.Keys)
+        {
+            var (label, pred) = loop_vars[matrix_cond];
+            if (variables_to_update.ContainsKey(matrix_cond))
+            {
+                var op = weighted_assign_add(label, pred, weights_tiled, variables_to_update[matrix_cond]);
+                update_ops.append(op);
+            }
+        }
+
+        tf.group(update_ops.ToArray());
+        return null;
+    }
 }
diff --git a/src/TensorFlowNET.Keras/Utils/losses_utils.cs b/src/TensorFlowNET.Keras/Utils/losses_utils.cs
index 6de98861..717acf5e 100644
--- a/src/TensorFlowNET.Keras/Utils/losses_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/losses_utils.cs
@@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Utils
             });
         }
 
-        public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null)
+        public static (Tensor, Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null)
         {
             var y_pred_shape = y_pred.shape;
             var y_pred_rank = y_pred_shape.ndim;
@@ -57,13 +57,13 @@
 
             if (sample_weight == null)
             {
-                return (y_pred, y_true);
+                return (y_pred, y_true, sample_weight);
             }
 
             var weights_shape = sample_weight.shape;
             var weights_rank = weights_shape.ndim;
             if (weights_rank == 0)
-                return (y_pred, sample_weight);
+                return (y_pred, y_true, sample_weight);
 
             if (y_pred_rank > -1 && weights_rank > -1)
             {
@@ -77,7 +77,7 @@
             }
             else
             {
-                return (y_pred, sample_weight);
+                return (y_pred, y_true, sample_weight);
             }
         }
 
diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
index a35763d0..84382bb4 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
@@ -45,4 +45,24 @@ public class MetricsTest : EagerModeTestBase
         var m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k: 3);
         Assert.AreEqual(m.numpy(), new[] { 1f, 1f });
     }
+
+    /// <summary>
+    /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall
+    /// </summary>
+    [TestMethod]
+    public void Recall()
+    {
+        var y_true = np.array(new[] { 0, 1, 1, 1 });
+        var y_pred = np.array(new[] { 1, 0, 1, 1 });
+        var m = tf.keras.metrics.Recall();
+        m.update_state(y_true, y_pred);
+        var r = m.result().numpy();
+        Assert.AreEqual(r, 0.6666667f);
+
+        m.reset_states();
+        var weights = np.array(new[] { 0f, 0f, 1f, 0f });
+        m.update_state(y_true, y_pred, sample_weight: weights);
+        r = m.result().numpy();
+        Assert.AreEqual(r, 1f);
+    }
 }
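
A minimal usage sketch of the Recall metric added above (illustration only, not part of the patch), assuming a console project that references the TensorFlow.NET and TensorFlow.Keras packages. The values mirror the Recall unit test: recall = TP / (TP + FN).

    using System;
    using Tensorflow.NumPy;
    using static Tensorflow.Binding;

    // Three positives in y_true; y_pred recovers two of them (indices 2 and 3)
    // and misses one (index 1), so unweighted recall = 2 / (2 + 1) = 0.6666667.
    var y_true = np.array(new[] { 0, 1, 1, 1 });
    var y_pred = np.array(new[] { 1, 0, 1, 1 });

    var recall = tf.keras.metrics.Recall();
    recall.update_state(y_true, y_pred);
    Console.WriteLine(recall.result().numpy());  // ~0.6666667

    // With sample_weight = [0, 0, 1, 0] only index 2 contributes,
    // giving weighted TP = 1 and FN = 0, so recall = 1 / (1 + 0) = 1.
    recall.reset_states();
    recall.update_state(y_true, y_pred, sample_weight: np.array(new[] { 0f, 0f, 1f, 0f }));
    Console.WriteLine(recall.result().numpy());  // 1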