| @@ -35,4 +35,15 @@ public interface IMetricsApi | |||||
| /// <param name="k"></param> | /// <param name="k"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT); | IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT); | ||||
| /// <summary> | |||||
| /// Computes the recall of the predictions with respect to the labels. | |||||
| /// </summary> | |||||
| /// <param name="thresholds"></param> | |||||
| /// <param name="top_k"></param> | |||||
| /// <param name="class_id"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <returns></returns> | |||||
| IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
| } | } | ||||
| @@ -221,6 +221,9 @@ namespace Tensorflow | |||||
| case Tensor t: | case Tensor t: | ||||
| dtype = t.dtype.as_base_dtype(); | dtype = t.dtype.as_base_dtype(); | ||||
| break; | break; | ||||
| case int t: | |||||
| dtype = TF_DataType.TF_INT32; | |||||
| break; | |||||
| } | } | ||||
| if (dtype != TF_DataType.DtInvalid) | if (dtype != TF_DataType.DtInvalid) | ||||
| @@ -1,5 +1,7 @@ | |||||
| global using System; | global using System; | ||||
| global using System.Collections.Generic; | global using System.Collections.Generic; | ||||
| global using System.Text; | global using System.Text; | ||||
| global using System.Linq; | |||||
| global using static Tensorflow.Binding; | global using static Tensorflow.Binding; | ||||
| global using static Tensorflow.KerasApi; | |||||
| global using static Tensorflow.KerasApi; | |||||
| global using Tensorflow.NumPy; | |||||
| @@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Metrics | |||||
| y_true = math_ops.cast(y_true, _dtype); | y_true = math_ops.cast(y_true, _dtype); | ||||
| y_pred = math_ops.cast(y_pred, _dtype); | y_pred = math_ops.cast(y_pred, _dtype); | ||||
| (y_pred, y_true) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true); | |||||
| (y_pred, y_true, _) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true); | |||||
| var matches = _fn(y_true, y_pred); | var matches = _fn(y_true, y_pred); | ||||
| return update_state(matches, sample_weight: sample_weight); | return update_state(matches, sample_weight: sample_weight); | ||||
| @@ -61,5 +61,8 @@ | |||||
| public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | ||||
| => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); | => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); | ||||
| public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| => new Recall(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,53 @@ | |||||
| namespace Tensorflow.Keras.Metrics; | |||||
| public class Recall : Metric | |||||
| { | |||||
| Tensor _thresholds; | |||||
| int _top_k; | |||||
| int _class_id; | |||||
| IVariableV1 true_positives; | |||||
| IVariableV1 false_negatives; | |||||
| bool _thresholds_distributed_evenly; | |||||
| public Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| : base(name: name, dtype: dtype) | |||||
| { | |||||
| _thresholds = constant_op.constant(new float[] { thresholds }); | |||||
| true_positives = add_weight("true_positives", shape: 1, initializer: tf.initializers.zeros_initializer()); | |||||
| false_negatives = add_weight("false_negatives", shape: 1, initializer: tf.initializers.zeros_initializer()); | |||||
| } | |||||
| public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) | |||||
| { | |||||
| return metrics_utils.update_confusion_matrix_variables( | |||||
| new Dictionary<string, IVariableV1> | |||||
| { | |||||
| { "tp", true_positives }, | |||||
| { "fn", false_negatives }, | |||||
| }, | |||||
| y_true, | |||||
| y_pred, | |||||
| thresholds: _thresholds, | |||||
| thresholds_distributed_evenly: _thresholds_distributed_evenly, | |||||
| top_k: _top_k, | |||||
| class_id: _class_id, | |||||
| sample_weight: sample_weight); | |||||
| } | |||||
| public override Tensor result() | |||||
| { | |||||
| var result = tf.divide(true_positives.AsTensor(), tf.add(true_positives, false_negatives)); | |||||
| return _thresholds.size == 1 ? result[0] : result; | |||||
| } | |||||
| public override void reset_states() | |||||
| { | |||||
| var num_thresholds = (int)_thresholds.size; | |||||
| keras.backend.batch_set_value( | |||||
| new List<(IVariableV1, NDArray)> | |||||
| { | |||||
| (true_positives, np.zeros(num_thresholds)), | |||||
| (false_negatives, np.zeros(num_thresholds)) | |||||
| }); | |||||
| } | |||||
| } | |||||
| @@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Metrics | |||||
| { | { | ||||
| if (sample_weight != null) | if (sample_weight != null) | ||||
| { | { | ||||
| (values, sample_weight) = losses_utils.squeeze_or_expand_dimensions( | |||||
| (values, _, sample_weight) = losses_utils.squeeze_or_expand_dimensions( | |||||
| values, sample_weight: sample_weight); | values, sample_weight: sample_weight); | ||||
| sample_weight = math_ops.cast(sample_weight, dtype: values.dtype); | sample_weight = math_ops.cast(sample_weight, dtype: values.dtype); | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Tensorflow.NumPy; | |||||
| using Tensorflow.Keras.Utils; | |||||
| using Tensorflow.NumPy; | |||||
| namespace Tensorflow.Keras.Metrics; | namespace Tensorflow.Keras.Metrics; | ||||
| @@ -36,4 +37,172 @@ public class metrics_utils | |||||
| return matches; | return matches; | ||||
| } | } | ||||
| public static Tensor update_confusion_matrix_variables(Dictionary<string, IVariableV1> variables_to_update, | |||||
| Tensor y_true, | |||||
| Tensor y_pred, | |||||
| Tensor thresholds, | |||||
| int top_k, | |||||
| int class_id, | |||||
| Tensor sample_weight = null, | |||||
| bool multi_label = false, | |||||
| Tensor label_weights = null, | |||||
| bool thresholds_distributed_evenly = false) | |||||
| { | |||||
| var variable_dtype = variables_to_update.Values.First().dtype; | |||||
| y_true = tf.cast(y_true, dtype: variable_dtype); | |||||
| y_pred = tf.cast(y_pred, dtype: variable_dtype); | |||||
| var num_thresholds = thresholds.shape.dims[0]; | |||||
| Tensor one_thresh = null; | |||||
| if (multi_label) | |||||
| { | |||||
| one_thresh = tf.equal(tf.cast(constant_op.constant(1), dtype:tf.int32), | |||||
| tf.rank(thresholds), | |||||
| name: "one_set_of_thresholds_cond"); | |||||
| } | |||||
| else | |||||
| { | |||||
| one_thresh = tf.cast(constant_op.constant(true), dtype: dtypes.@bool); | |||||
| } | |||||
| if (sample_weight == null) | |||||
| { | |||||
| (y_pred, y_true, _) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true); | |||||
| } | |||||
| else | |||||
| { | |||||
| sample_weight = tf.cast(sample_weight, dtype: variable_dtype); | |||||
| (y_pred, y_true, sample_weight) = losses_utils.squeeze_or_expand_dimensions(y_pred, | |||||
| y_true, | |||||
| sample_weight: sample_weight); | |||||
| } | |||||
| if (thresholds_distributed_evenly) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| var pred_shape = tf.shape(y_pred); | |||||
| var num_predictions = pred_shape[0]; | |||||
| Tensor num_labels; | |||||
| if (y_pred.shape.ndim == 1) | |||||
| { | |||||
| num_labels = constant_op.constant(1); | |||||
| } | |||||
| else | |||||
| { | |||||
| num_labels = tf.reduce_prod(pred_shape["1:"], axis: 0); | |||||
| } | |||||
| var thresh_label_tile = tf.where(one_thresh, num_labels, tf.ones(new int[0], dtype: tf.int32)); | |||||
| // Reshape predictions and labels, adding a dim for thresholding. | |||||
| Tensor predictions_extra_dim, labels_extra_dim; | |||||
| if (multi_label) | |||||
| { | |||||
| predictions_extra_dim = tf.expand_dims(y_pred, 0); | |||||
| labels_extra_dim = tf.expand_dims(tf.cast(y_true, dtype: tf.@bool), 0); | |||||
| } | |||||
| else | |||||
| { | |||||
| // Flatten predictions and labels when not multilabel. | |||||
| predictions_extra_dim = tf.reshape(y_pred, (1, -1)); | |||||
| labels_extra_dim = tf.reshape(tf.cast(y_true, dtype: tf.@bool), (1, -1)); | |||||
| } | |||||
| // Tile the thresholds for every prediction. | |||||
| object[] thresh_pretile_shape, thresh_tiles, data_tiles; | |||||
| if (multi_label) | |||||
| { | |||||
| thresh_pretile_shape = new object[] { num_thresholds, 1, -1 }; | |||||
| thresh_tiles = new object[] { 1, num_predictions, thresh_label_tile }; | |||||
| data_tiles = new object[] { num_thresholds, 1, 1 }; | |||||
| } | |||||
| else | |||||
| { | |||||
| thresh_pretile_shape = new object[] { num_thresholds, -1 }; | |||||
| thresh_tiles = new object[] { 1, num_predictions * num_labels }; | |||||
| data_tiles = new object[] { num_thresholds, 1 }; | |||||
| } | |||||
| var thresh_tiled = tf.tile(tf.reshape(thresholds, thresh_pretile_shape), tf.stack(thresh_tiles)); | |||||
| // Tile the predictions for every threshold. | |||||
| var preds_tiled = tf.tile(predictions_extra_dim, data_tiles); | |||||
| // Compare predictions and threshold. | |||||
| var pred_is_pos = tf.greater(preds_tiled, thresh_tiled); | |||||
| // Tile labels by number of thresholds | |||||
| var label_is_pos = tf.tile(labels_extra_dim, data_tiles); | |||||
| Tensor weights_tiled = null; | |||||
| if (sample_weight != null) | |||||
| { | |||||
| /*sample_weight = broadcast_weights( | |||||
| tf.cast(sample_weight, dtype: variable_dtype), y_pred);*/ | |||||
| weights_tiled = tf.tile( | |||||
| tf.reshape(sample_weight, thresh_tiles), data_tiles); | |||||
| } | |||||
| if (label_weights != null && !multi_label) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| Func<Tensor, Tensor, Tensor, IVariableV1, ITensorOrOperation> weighted_assign_add | |||||
| = (label, pred, weights, var) => | |||||
| { | |||||
| var label_and_pred = tf.cast(tf.logical_and(label, pred), dtype: var.dtype); | |||||
| if (weights != null) | |||||
| { | |||||
| label_and_pred *= tf.cast(weights, dtype: var.dtype); | |||||
| } | |||||
| return var.assign_add(tf.reduce_sum(label_and_pred, 1)); | |||||
| }; | |||||
| var loop_vars = new Dictionary<string, (Tensor, Tensor)> | |||||
| { | |||||
| { "tp", (label_is_pos, pred_is_pos) } | |||||
| }; | |||||
| var update_tn = variables_to_update.ContainsKey("tn"); | |||||
| var update_fp = variables_to_update.ContainsKey("fp"); | |||||
| var update_fn = variables_to_update.ContainsKey("fn"); | |||||
| Tensor pred_is_neg = null; | |||||
| if (update_fn || update_tn) | |||||
| { | |||||
| pred_is_neg = tf.logical_not(pred_is_pos); | |||||
| loop_vars["fn"] = (label_is_pos, pred_is_neg); | |||||
| } | |||||
| if(update_fp || update_tn) | |||||
| { | |||||
| var label_is_neg = tf.logical_not(label_is_pos); | |||||
| loop_vars["fp"] = (label_is_neg, pred_is_pos); | |||||
| if (update_tn) | |||||
| { | |||||
| loop_vars["tn"] = (label_is_neg, pred_is_neg); | |||||
| } | |||||
| } | |||||
| var update_ops = new List<ITensorOrOperation>(); | |||||
| foreach (var matrix_cond in loop_vars.Keys) | |||||
| { | |||||
| var (label, pred) = loop_vars[matrix_cond]; | |||||
| if (variables_to_update.ContainsKey(matrix_cond)) | |||||
| { | |||||
| var op = weighted_assign_add(label, pred, weights_tiled, variables_to_update[matrix_cond]); | |||||
| update_ops.append(op); | |||||
| } | |||||
| } | |||||
| tf.group(update_ops.ToArray()); | |||||
| return null; | |||||
| } | |||||
| } | } | ||||
| @@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Utils | |||||
| }); | }); | ||||
| } | } | ||||
| public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null) | |||||
| public static (Tensor, Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null) | |||||
| { | { | ||||
| var y_pred_shape = y_pred.shape; | var y_pred_shape = y_pred.shape; | ||||
| var y_pred_rank = y_pred_shape.ndim; | var y_pred_rank = y_pred_shape.ndim; | ||||
| @@ -57,13 +57,13 @@ namespace Tensorflow.Keras.Utils | |||||
| if (sample_weight == null) | if (sample_weight == null) | ||||
| { | { | ||||
| return (y_pred, y_true); | |||||
| return (y_pred, y_true, sample_weight); | |||||
| } | } | ||||
| var weights_shape = sample_weight.shape; | var weights_shape = sample_weight.shape; | ||||
| var weights_rank = weights_shape.ndim; | var weights_rank = weights_shape.ndim; | ||||
| if (weights_rank == 0) | if (weights_rank == 0) | ||||
| return (y_pred, sample_weight); | |||||
| return (y_pred, y_true, sample_weight); | |||||
| if (y_pred_rank > -1 && weights_rank > -1) | if (y_pred_rank > -1 && weights_rank > -1) | ||||
| { | { | ||||
| @@ -77,7 +77,7 @@ namespace Tensorflow.Keras.Utils | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| return (y_pred, sample_weight); | |||||
| return (y_pred, y_true, sample_weight); | |||||
| } | } | ||||
| } | } | ||||
| @@ -45,4 +45,24 @@ public class MetricsTest : EagerModeTestBase | |||||
| var m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k: 3); | var m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k: 3); | ||||
| Assert.AreEqual(m.numpy(), new[] { 1f, 1f }); | Assert.AreEqual(m.numpy(), new[] { 1f, 1f }); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void Recall() | |||||
| { | |||||
| var y_true = np.array(new[] { 0, 1, 1, 1 }); | |||||
| var y_pred = np.array(new[] { 1, 0, 1, 1 }); | |||||
| var m = tf.keras.metrics.Recall(); | |||||
| m.update_state(y_true, y_pred); | |||||
| var r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 0.6666667f); | |||||
| m.reset_states(); | |||||
| var weights = np.array(new[] { 0f, 0f, 1f, 0f }); | |||||
| m.update_state(y_true, y_pred, sample_weight: weights); | |||||
| r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 1f); | |||||
| } | |||||
| } | } | ||||