diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
index 511b0ef1..95cc1e60 100644
--- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
@@ -35,4 +35,15 @@ public interface IMetricsApi
/// <param name="dtype"></param>
/// <returns></returns>
IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT);
+
+ /// <summary>
+ /// Computes the recall of the predictions with respect to the labels.
+ /// </summary>
+ /// <param name="thresholds"></param>
+ /// <param name="top_k"></param>
+ /// <param name="class_id"></param>
+ /// <param name="name"></param>
+ /// <param name="dtype"></param>
+ /// <returns></returns>
+ IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
}
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index 263509f6..0e888a0a 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -221,6 +221,9 @@ namespace Tensorflow
case Tensor t:
dtype = t.dtype.as_base_dtype();
break;
+ case int t:
+ dtype = TF_DataType.TF_INT32;
+ break;
}
if (dtype != TF_DataType.DtInvalid)
diff --git a/src/TensorFlowNET.Keras/GlobalUsing.cs b/src/TensorFlowNET.Keras/GlobalUsing.cs
index 72ff8b28..bc0798ed 100644
--- a/src/TensorFlowNET.Keras/GlobalUsing.cs
+++ b/src/TensorFlowNET.Keras/GlobalUsing.cs
@@ -1,5 +1,7 @@
global using System;
global using System.Collections.Generic;
global using System.Text;
+global using System.Linq;
global using static Tensorflow.Binding;
-global using static Tensorflow.KerasApi;
\ No newline at end of file
+global using static Tensorflow.KerasApi;
+global using Tensorflow.NumPy;
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
index 2e985b88..7173aae1 100644
--- a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Metrics
y_true = math_ops.cast(y_true, _dtype);
y_pred = math_ops.cast(y_pred, _dtype);
- (y_pred, y_true) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true);
+ (y_pred, y_true, _) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true);
var matches = _fn(y_true, y_pred);
return update_state(matches, sample_weight: sample_weight);
diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
index dfccfdbb..23742832 100644
--- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
@@ -61,5 +61,8 @@
public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);
+
+ public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT)
+ => new Recall(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype);
}
}
diff --git a/src/TensorFlowNET.Keras/Metrics/Recall.cs b/src/TensorFlowNET.Keras/Metrics/Recall.cs
new file mode 100644
index 00000000..9b58bf5f
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Metrics/Recall.cs
@@ -0,0 +1,53 @@
+namespace Tensorflow.Keras.Metrics;
+
+public class Recall : Metric
+{
+ Tensor _thresholds;
+ int _top_k;
+ int _class_id;
+ IVariableV1 true_positives;
+ IVariableV1 false_negatives;
+ bool _thresholds_distributed_evenly;
+
+ public Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT)
+ : base(name: name, dtype: dtype)
+ {
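+ // The scalar threshold is stored as a 1-element tensor so result() and
+ // reset_states() treat single- and multi-threshold cases uniformly.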
+ _thresholds = constant_op.constant(new float[] { thresholds });
+ _top_k = top_k;
+ _class_id = class_id;
+ true_positives = add_weight("true_positives", shape: 1, initializer: tf.initializers.zeros_initializer());
+ false_negatives = add_weight("false_negatives", shape: 1, initializer: tf.initializers.zeros_initializer());
+ }
+
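+ // Accumulates true positives and false negatives for the batch via the
+ // shared confusion-matrix helper.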
+ public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
+ {
+ return metrics_utils.update_confusion_matrix_variables(
+ new Dictionary<string, IVariableV1>
+ {
+ { "tp", true_positives },
+ { "fn", false_negatives },
+ },
+ y_true,
+ y_pred,
+ thresholds: _thresholds,
+ thresholds_distributed_evenly: _thresholds_distributed_evenly,
+ top_k: _top_k,
+ class_id: _class_id,
+ sample_weight: sample_weight);
+ }
+
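+ // recall = TP / (TP + FN); with a single threshold the scalar element is returned.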
+ public override Tensor result()
+ {
+ var result = tf.divide(true_positives.AsTensor(), tf.add(true_positives, false_negatives));
+ return _thresholds.size == 1 ? result[0] : result;
+ }
+
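+ // Resets both accumulators to zero, one slot per configured threshold.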
+ public override void reset_states()
+ {
+ var num_thresholds = (int)_thresholds.size;
+ keras.backend.batch_set_value(
+ new List<(IVariableV1, NDArray)>
+ {
+ (true_positives, np.zeros(num_thresholds)),
+ (false_negatives, np.zeros(num_thresholds))
+ });
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Metrics/Reduce.cs b/src/TensorFlowNET.Keras/Metrics/Reduce.cs
index f7cdb8f5..8874719d 100644
--- a/src/TensorFlowNET.Keras/Metrics/Reduce.cs
+++ b/src/TensorFlowNET.Keras/Metrics/Reduce.cs
@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Metrics
{
if (sample_weight != null)
{
- (values, sample_weight) = losses_utils.squeeze_or_expand_dimensions(
+ (values, _, sample_weight) = losses_utils.squeeze_or_expand_dimensions(
values, sample_weight: sample_weight);
sample_weight = math_ops.cast(sample_weight, dtype: values.dtype);
diff --git a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
index de6a8402..d09b3c72 100644
--- a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
+++ b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
@@ -1,4 +1,5 @@
-using Tensorflow.NumPy;
+using Tensorflow.Keras.Utils;
+using Tensorflow.NumPy;
namespace Tensorflow.Keras.Metrics;
@@ -36,4 +37,172 @@ public class metrics_utils
return matches;
}
+
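+ // Updates the requested confusion-matrix variables ("tp", "fp", "tn", "fn") in place:
+ // predictions are tiled once per threshold, compared against the tiled thresholds,
+ // combined with the (equally tiled) labels, and the per-threshold counts are added
+ // to each variable.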
+ public static Tensor update_confusion_matrix_variables(Dictionary<string, IVariableV1> variables_to_update,
+ Tensor y_true,
+ Tensor y_pred,
+ Tensor thresholds,
+ int top_k,
+ int class_id,
+ Tensor sample_weight = null,
+ bool multi_label = false,
+ Tensor label_weights = null,
+ bool thresholds_distributed_evenly = false)
+ {
+ var variable_dtype = variables_to_update.Values.First().dtype;
+ y_true = tf.cast(y_true, dtype: variable_dtype);
+ y_pred = tf.cast(y_pred, dtype: variable_dtype);
+ var num_thresholds = thresholds.shape.dims[0];
+
+ Tensor one_thresh = null;
+ if (multi_label)
+ {
+ one_thresh = tf.equal(tf.cast(constant_op.constant(1), dtype: tf.int32),
+ tf.rank(thresholds),
+ name: "one_set_of_thresholds_cond");
+ }
+ else
+ {
+ one_thresh = tf.cast(constant_op.constant(true), dtype: dtypes.@bool);
+ }
+
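+ // Align the ranks of y_pred / y_true (and sample_weight, when provided) before thresholding.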
+ if (sample_weight == null)
+ {
+ (y_pred, y_true, _) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true);
+ }
+ else
+ {
+ sample_weight = tf.cast(sample_weight, dtype: variable_dtype);
+ (y_pred, y_true, sample_weight) = losses_utils.squeeze_or_expand_dimensions(y_pred,
+ y_true,
+ sample_weight: sample_weight);
+ }
+
+ if (thresholds_distributed_evenly)
+ {
+ throw new NotImplementedException();
+ }
+
+ var pred_shape = tf.shape(y_pred);
+ var num_predictions = pred_shape[0];
+
+ Tensor num_labels;
+ if (y_pred.shape.ndim == 1)
+ {
+ num_labels = constant_op.constant(1);
+ }
+ else
+ {
+ num_labels = tf.reduce_prod(pred_shape["1:"], axis: 0);
+ }
+ var thresh_label_tile = tf.where(one_thresh, num_labels, tf.ones(new int[0], dtype: tf.int32));
+
+ // Reshape predictions and labels, adding a dim for thresholding.
+ Tensor predictions_extra_dim, labels_extra_dim;
+ if (multi_label)
+ {
+ predictions_extra_dim = tf.expand_dims(y_pred, 0);
+ labels_extra_dim = tf.expand_dims(tf.cast(y_true, dtype: tf.@bool), 0);
+ }
+
+ else
+ {
+ // Flatten predictions and labels when not multilabel.
+ predictions_extra_dim = tf.reshape(y_pred, (1, -1));
+ labels_extra_dim = tf.reshape(tf.cast(y_true, dtype: tf.@bool), (1, -1));
+ }
+
+ // Tile the thresholds for every prediction.
+ object[] thresh_pretile_shape, thresh_tiles, data_tiles;
+
+ if (multi_label)
+ {
+ thresh_pretile_shape = new object[] { num_thresholds, 1, -1 };
+ thresh_tiles = new object[] { 1, num_predictions, thresh_label_tile };
+ data_tiles = new object[] { num_thresholds, 1, 1 };
+ }
+ else
+ {
+ thresh_pretile_shape = new object[] { num_thresholds, -1 };
+ thresh_tiles = new object[] { 1, num_predictions * num_labels };
+ data_tiles = new object[] { num_thresholds, 1 };
+ }
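+ // Shape sketch for the single-label branch (illustrative): with N flattened
+ // prediction values and T thresholds, the thresholds become (T, 1) tiled to (T, N)
+ // and the predictions become (1, N) tiled to (T, N), so each row of the
+ // comparisons below corresponds to one threshold.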
+ var thresh_tiled = tf.tile(tf.reshape(thresholds, thresh_pretile_shape), tf.stack(thresh_tiles));
+
+ // Tile the predictions for every threshold.
+ var preds_tiled = tf.tile(predictions_extra_dim, data_tiles);
+
+ // Compare predictions and threshold.
+ var pred_is_pos = tf.greater(preds_tiled, thresh_tiled);
+
+ // Tile labels by number of thresholds
+ var label_is_pos = tf.tile(labels_extra_dim, data_tiles);
+
+ Tensor weights_tiled = null;
+
+ if (sample_weight != null)
+ {
+ /*sample_weight = broadcast_weights(
+ tf.cast(sample_weight, dtype: variable_dtype), y_pred);*/
+ weights_tiled = tf.tile(
+ tf.reshape(sample_weight, thresh_tiles), data_tiles);
+ }
+
+ if (label_weights != null && !multi_label)
+ {
+ throw new NotImplementedException();
+ }
+
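+ // For a given entry, count the positions where both boolean masks hold (optionally
+ // weighted), sum them per threshold row, and add the row sums to the variable.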
+ Func<Tensor, Tensor, Tensor, IVariableV1, ITensorOrOperation> weighted_assign_add
+ = (label, pred, weights, var) =>
+ {
+ var label_and_pred = tf.cast(tf.logical_and(label, pred), dtype: var.dtype);
+ if (weights != null)
+ {
+ label_and_pred *= tf.cast(weights, dtype: var.dtype);
+ }
+
+ return var.assign_add(tf.reduce_sum(label_and_pred, 1));
+ };
+
+
+ var loop_vars = new Dictionary<string, (Tensor, Tensor)>
+ {
+ { "tp", (label_is_pos, pred_is_pos) }
+ };
+ var update_tn = variables_to_update.ContainsKey("tn");
+ var update_fp = variables_to_update.ContainsKey("fp");
+ var update_fn = variables_to_update.ContainsKey("fn");
+
+ Tensor pred_is_neg = null;
+ if (update_fn || update_tn)
+ {
+ pred_is_neg = tf.logical_not(pred_is_pos);
+ loop_vars["fn"] = (label_is_pos, pred_is_neg);
+ }
+
+ if (update_fp || update_tn)
+ {
+ var label_is_neg = tf.logical_not(label_is_pos);
+ loop_vars["fp"] = (label_is_neg, pred_is_pos);
+ if (update_tn)
+ {
+ loop_vars["tn"] = (label_is_neg, pred_is_neg);
+ }
+ }
+
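+ // Build one assign_add update per confusion-matrix entry that was actually requested.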
+ var update_ops = new List<ITensorOrOperation>();
+ foreach (var matrix_cond in loop_vars.Keys)
+ {
+ var (label, pred) = loop_vars[matrix_cond];
+ if (variables_to_update.ContainsKey(matrix_cond))
+ {
+ var op = weighted_assign_add(label, pred, weights_tiled, variables_to_update[matrix_cond]);
+ update_ops.append(op);
+ }
+ }
+
+ tf.group(update_ops.ToArray());
+ return null;
+ }
}
diff --git a/src/TensorFlowNET.Keras/Utils/losses_utils.cs b/src/TensorFlowNET.Keras/Utils/losses_utils.cs
index 6de98861..717acf5e 100644
--- a/src/TensorFlowNET.Keras/Utils/losses_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/losses_utils.cs
@@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Utils
});
}
- public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null)
+ public static (Tensor, Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null)
{
var y_pred_shape = y_pred.shape;
var y_pred_rank = y_pred_shape.ndim;
@@ -57,13 +57,13 @@ namespace Tensorflow.Keras.Utils
if (sample_weight == null)
{
- return (y_pred, y_true);
+ return (y_pred, y_true, sample_weight);
}
var weights_shape = sample_weight.shape;
var weights_rank = weights_shape.ndim;
if (weights_rank == 0)
- return (y_pred, sample_weight);
+ return (y_pred, y_true, sample_weight);
if (y_pred_rank > -1 && weights_rank > -1)
{
@@ -77,7 +77,7 @@ namespace Tensorflow.Keras.Utils
}
else
{
- return (y_pred, sample_weight);
+ return (y_pred, y_true, sample_weight);
}
}
diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
index a35763d0..84382bb4 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
@@ -45,4 +45,24 @@ public class MetricsTest : EagerModeTestBase
var m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k: 3);
Assert.AreEqual(m.numpy(), new[] { 1f, 1f });
}
+
+ /// <summary>
+ /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall
+ /// </summary>
+ [TestMethod]
+ public void Recall()
+ {
+ var y_true = np.array(new[] { 0, 1, 1, 1 });
+ var y_pred = np.array(new[] { 1, 0, 1, 1 });
+ var m = tf.keras.metrics.Recall();
+ m.update_state(y_true, y_pred);
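+ // Unweighted: 2 true positives and 1 false negative -> recall = 2/3.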
+ var r = m.result().numpy();
+ Assert.AreEqual(r, 0.6666667f);
+
+ m.reset_states();
+ var weights = np.array(new[] { 0f, 0f, 1f, 0f });
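+ // Only the third sample is weighted: 1 weighted TP, 0 weighted FN -> recall = 1.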
+ m.update_state(y_true, y_pred, sample_weight: weights);
+ r = m.result().numpy();
+ Assert.AreEqual(r, 1f);
+ }
}