| @@ -0,0 +1,17 @@ | |||||
| namespace Tensorflow.Keras.Metrics; | |||||
| public interface IMetricFunc | |||||
| { | |||||
| /// <summary> | |||||
| /// Accumulates metric statistics. | |||||
| /// </summary> | |||||
| /// <param name="y_true"></param> | |||||
| /// <param name="y_pred"></param> | |||||
| /// <param name="sample_weight"></param> | |||||
| /// <returns></returns> | |||||
| Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null); | |||||
| Tensor result(); | |||||
| void reset_states(); | |||||
| } | |||||
| @@ -26,4 +26,13 @@ public interface IMetricsApi | |||||
| /// <param name="k"></param> | /// <param name="k"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5); | Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5); | ||||
| /// <summary> | |||||
| /// Computes how often targets are in the top K predictions. | |||||
| /// </summary> | |||||
| /// <param name="y_true"></param> | |||||
| /// <param name="y_pred"></param> | |||||
| /// <param name="k"></param> | |||||
| /// <returns></returns> | |||||
| IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow.Keras.Utils; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| @@ -17,6 +18,8 @@ namespace Tensorflow.Keras.Metrics | |||||
| y_true = math_ops.cast(y_true, _dtype); | y_true = math_ops.cast(y_true, _dtype); | ||||
| y_pred = math_ops.cast(y_pred, _dtype); | y_pred = math_ops.cast(y_pred, _dtype); | ||||
| (y_pred, y_true) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true); | |||||
| var matches = _fn(y_true, y_pred); | var matches = _fn(y_true, y_pred); | ||||
| return update_state(matches, sample_weight: sample_weight); | return update_state(matches, sample_weight: sample_weight); | ||||
| } | } | ||||
| @@ -9,7 +9,7 @@ namespace Tensorflow.Keras.Metrics | |||||
| /// <summary> | /// <summary> | ||||
| /// Encapsulates metric logic and state. | /// Encapsulates metric logic and state. | ||||
| /// </summary> | /// </summary> | ||||
| public class Metric : Layer | |||||
| public class Metric : Layer, IMetricFunc | |||||
| { | { | ||||
| protected IVariableV1 total; | protected IVariableV1 total; | ||||
| protected IVariableV1 count; | protected IVariableV1 count; | ||||
| @@ -1,6 +1,4 @@ | |||||
| using static Tensorflow.KerasApi; | |||||
| namespace Tensorflow.Keras.Metrics | |||||
| namespace Tensorflow.Keras.Metrics | |||||
| { | { | ||||
| public class MetricsApi : IMetricsApi | public class MetricsApi : IMetricsApi | ||||
| { | { | ||||
| @@ -60,5 +58,8 @@ namespace Tensorflow.Keras.Metrics | |||||
| tf.math.argmax(y_true, axis: -1), y_pred, k | tf.math.argmax(y_true, axis: -1), y_pred, k | ||||
| ); | ); | ||||
| } | } | ||||
| public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,12 @@ | |||||
| namespace Tensorflow.Keras.Metrics; | |||||
| public class TopKCategoricalAccuracy : MeanMetricWrapper | |||||
| { | |||||
| public TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| : base((yt, yp) => metrics_utils.sparse_top_k_categorical_matches( | |||||
| tf.math.argmax(yt, axis: -1), yp, k), | |||||
| name: name, | |||||
| dtype: dtype) | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Xml.Linq; | |||||
| using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -37,15 +38,57 @@ namespace Tensorflow.Keras.Utils | |||||
| }); | }); | ||||
| } | } | ||||
| public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor sample_weight) | |||||
| public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null) | |||||
| { | { | ||||
| var y_pred_shape = y_pred.shape; | |||||
| var y_pred_rank = y_pred_shape.ndim; | |||||
| if (y_true != null) | |||||
| { | |||||
| var y_true_shape = y_true.shape; | |||||
| var y_true_rank = y_true_shape.ndim; | |||||
| if (y_true_rank > -1 && y_pred_rank > -1) | |||||
| { | |||||
| if (y_pred_rank - y_true_rank != 1 || y_pred_shape[-1] == 1) | |||||
| { | |||||
| (y_true, y_pred) = remove_squeezable_dimensions(y_true, y_pred); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (sample_weight == null) | |||||
| { | |||||
| return (y_pred, y_true); | |||||
| } | |||||
| var weights_shape = sample_weight.shape; | var weights_shape = sample_weight.shape; | ||||
| var weights_rank = weights_shape.ndim; | var weights_rank = weights_shape.ndim; | ||||
| if (weights_rank == 0) | if (weights_rank == 0) | ||||
| return (y_pred, sample_weight); | return (y_pred, sample_weight); | ||||
| if (y_pred_rank > -1 && weights_rank > -1) | |||||
| { | |||||
| if (weights_rank - y_pred_rank == 1) | |||||
| { | |||||
| sample_weight = tf.squeeze(sample_weight, -1); | |||||
| } | |||||
| else if (y_pred_rank - weights_rank == 1) | |||||
| { | |||||
| sample_weight = tf.expand_dims(sample_weight, -1); | |||||
| } | |||||
| else | |||||
| { | |||||
| return (y_pred, sample_weight); | |||||
| } | |||||
| } | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| public static (Tensor, Tensor) remove_squeezable_dimensions(Tensor labels, Tensor predictions, int expected_rank_diff = 0, string name = null) | |||||
| { | |||||
| return (labels, predictions); | |||||
| } | |||||
| public static Tensor reduce_weighted_loss(Tensor weighted_losses, string reduction) | public static Tensor reduce_weighted_loss(Tensor weighted_losses, string reduction) | ||||
| { | { | ||||
| if (reduction == ReductionV2.NONE) | if (reduction == ReductionV2.NONE) | ||||
| @@ -14,6 +14,26 @@ namespace TensorFlowNET.Keras.UnitTest; | |||||
| [TestClass] | [TestClass] | ||||
| public class MetricsTest : EagerModeTestBase | public class MetricsTest : EagerModeTestBase | ||||
| { | { | ||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void TopKCategoricalAccuracy() | |||||
| { | |||||
| var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } }); | |||||
| var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } }); | |||||
| var m = tf.keras.metrics.TopKCategoricalAccuracy(k: 1); | |||||
| m.update_state(y_true, y_pred); | |||||
| var r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 0.5f); | |||||
| m.reset_states(); | |||||
| var weights = np.array(new[] { 0.7f, 0.3f }); | |||||
| m.update_state(y_true, y_pred, sample_weight: weights); | |||||
| r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 0.3f); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy | /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy | ||||
| /// </summary> | /// </summary> | ||||