| @@ -22,6 +22,20 @@ public interface IMetricsApi | |||||
| /// <returns>Sparse categorical accuracy values.</returns> | /// <returns>Sparse categorical accuracy values.</returns> | ||||
| Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred); | Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred); | ||||
| /// <summary> | |||||
| /// Computes the sparse categorical crossentropy loss. | |||||
| /// </summary> | |||||
| /// <param name="y_true"></param> | |||||
| /// <param name="y_pred"></param> | |||||
| /// <param name="from_logits"></param> | |||||
| /// <param name="ignore_class"></param> | |||||
| /// <param name="axis"></param> | |||||
| /// <returns></returns> | |||||
| Tensor sparse_categorical_crossentropy(Tensor y_true, Tensor y_pred, | |||||
| bool from_logits = false, | |||||
| int? ignore_class = null, | |||||
| Axis? axis = null); | |||||
| /// <summary> | /// <summary> | ||||
| /// Computes how often targets are in the top `K` predictions. | /// Computes how often targets are in the top `K` predictions. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -56,6 +70,16 @@ public interface IMetricsApi | |||||
| float label_smoothing = 0f, | float label_smoothing = 0f, | ||||
| Axis? axis = null); | Axis? axis = null); | ||||
| /// <summary> | |||||
| /// Computes the crossentropy metric between the labels and predictions. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| IMetricFunc SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy", | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| bool from_logits = false, | |||||
| int? ignore_class = null, | |||||
| Axis? axis = null); | |||||
| /// <summary> | /// <summary> | ||||
| /// Computes the crossentropy metric between the labels and predictions. | /// Computes the crossentropy metric between the labels and predictions. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -63,6 +87,13 @@ public interface IMetricsApi | |||||
| IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy", | IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy", | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT); | TF_DataType dtype = TF_DataType.TF_FLOAT); | ||||
| /// <summary> | |||||
| /// Calculates how often predictions match integer labels. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| IMetricFunc SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
| /// <summary> | /// <summary> | ||||
| /// Computes the cosine similarity between the labels and predictions. | /// Computes the cosine similarity between the labels and predictions. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -114,6 +145,15 @@ public interface IMetricsApi | |||||
| string name = "top_k_categorical_accuracy", | string name = "top_k_categorical_accuracy", | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT); | TF_DataType dtype = TF_DataType.TF_FLOAT); | ||||
| /// <summary> | |||||
| /// Computes how often integer targets are in the top K predictions. | |||||
| /// </summary> | |||||
| /// <param name="k"></param> | |||||
| /// <returns></returns> | |||||
| IMetricFunc SparseTopKCategoricalAccuracy(int k = 5, | |||||
| string name = "sparse_top_k_categorical_accuracy", | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
| /// <summary> | /// <summary> | ||||
| /// Computes the precision of the predictions with respect to the labels. | /// Computes the precision of the predictions with respect to the labels. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -276,6 +276,64 @@ namespace Tensorflow.Keras | |||||
| return -math_ops.reduce_sum(target * math_ops.log(output), new Axis(axis)); | return -math_ops.reduce_sum(target * math_ops.log(output), new Axis(axis)); | ||||
| } | } | ||||
| public Tensor sparse_categorical_crossentropy(Tensor target, Tensor output, bool from_logits = false, int axis = -1, int? ignore_class = null) | |||||
| { | |||||
| target = tf.cast(target, tf.int64); | |||||
| if (!from_logits) | |||||
| { | |||||
| var epsilon_ = constant_op.constant(epsilon(), output.dtype.as_base_dtype()); | |||||
| output = tf.clip_by_value(output, epsilon_, 1 - epsilon_); | |||||
| output = tf.math.log(output); | |||||
| } | |||||
| var output_rank = output.shape.ndim; | |||||
| if (output_rank > -1) | |||||
| { | |||||
| axis = Math.Abs(axis) % output_rank; | |||||
| if (axis != output_rank - 1) | |||||
| { | |||||
| /*var permutation = list( | |||||
| itertools.chain( | |||||
| range(axis), range(axis + 1, output_rank), [axis] | |||||
| ) | |||||
| ); | |||||
| output = tf.transpose(output, perm: permutation);*/ | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| var output_shape = tf.shape(output); | |||||
| var target_rank = target.shape.ndim; | |||||
| var update_shape = target_rank > -1 && output_rank > -1 && target_rank != output_rank - 1; | |||||
| if (update_shape) | |||||
| { | |||||
| /*var target = flatten(target); | |||||
| output = tf.reshape(output, [-1, output_shape[-1]]);*/ | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| if (ignore_class.HasValue) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| var res = tf.nn.sparse_softmax_cross_entropy_with_logits(labels: target, logits: output); | |||||
| if (ignore_class.HasValue) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| if (update_shape && output_rank >= 3) | |||||
| { | |||||
| // If our output includes timesteps or | |||||
| // spatial dimensions we need to reshape | |||||
| res = tf.reshape(res, output_shape[":-1"]); | |||||
| } | |||||
| return res; | |||||
| } | |||||
| public Tensor binary_crossentropy(Tensor target, Tensor output, bool from_logits = false) | public Tensor binary_crossentropy(Tensor target, Tensor output, bool from_logits = false) | ||||
| { | { | ||||
| if (from_logits) | if (from_logits) | ||||
| @@ -27,6 +27,11 @@ | |||||
| return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis); | return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis); | ||||
| } | } | ||||
| public Tensor sparse_categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, int? ignore_class = null, Axis? axis = null) | |||||
| { | |||||
| return keras.backend.sparse_categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis ?? -1, ignore_class: ignore_class); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Calculates how often predictions matches integer labels. | /// Calculates how often predictions matches integer labels. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -103,5 +108,14 @@ | |||||
| public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT) | public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT) | ||||
| => new Recall(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype); | => new Recall(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype); | ||||
| public IMetricFunc SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy", TF_DataType dtype = TF_DataType.TF_FLOAT, bool from_logits = false, int? ignore_class = null, Axis? axis = null) | |||||
| => new SparseCategoricalCrossentropy(name: name, dtype: dtype, from_logits: from_logits, ignore_class: ignore_class, axis: axis ?? -1); | |||||
| public IMetricFunc SparseTopKCategoricalAccuracy(int k = 5, string name = "sparse_top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| => new SparseTopKCategoricalAccuracy(k: k, name: name, dtype: dtype); | |||||
| public IMetricFunc SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| => new SparseCategoricalAccuracy(name: name, dtype: dtype); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,11 @@ | |||||
| namespace Tensorflow.Keras.Metrics; | |||||
| public class SparseCategoricalAccuracy : MeanMetricWrapper | |||||
| { | |||||
| public SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| : base((yt, yp) => metrics_utils.sparse_categorical_matches(yt, yp), | |||||
| name: name, | |||||
| dtype: dtype) | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,16 @@ | |||||
| namespace Tensorflow.Keras.Metrics; | |||||
| public class SparseCategoricalCrossentropy : MeanMetricWrapper | |||||
| { | |||||
| public SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy", | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| bool from_logits = false, | |||||
| int? ignore_class = null, | |||||
| Axis? axis = null) | |||||
| : base((yt, yp) => keras.metrics.sparse_categorical_crossentropy( | |||||
| yt, yp, from_logits: from_logits, ignore_class: ignore_class, axis: axis ?? -1), | |||||
| name: name, | |||||
| dtype: dtype) | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,11 @@ | |||||
| namespace Tensorflow.Keras.Metrics; | |||||
| public class SparseTopKCategoricalAccuracy : MeanMetricWrapper | |||||
| { | |||||
| public SparseTopKCategoricalAccuracy(int k = 5, string name = "sparse_top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| : base((yt, yp) => metrics_utils.sparse_top_k_categorical_matches(yt, yp, k), | |||||
| name: name, | |||||
| dtype: dtype) | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -73,7 +73,7 @@ public class metrics_utils | |||||
| y_true = tf.squeeze(y_true, new Shape(-1)); | y_true = tf.squeeze(y_true, new Shape(-1)); | ||||
| } | } | ||||
| y_pred = tf.math.argmax(y_pred, axis: -1); | y_pred = tf.math.argmax(y_pred, axis: -1); | ||||
| y_pred = tf.cast(y_pred, y_true.dtype); | |||||
| var matches = tf.cast( | var matches = tf.cast( | ||||
| tf.equal(y_true, y_pred), | tf.equal(y_true, y_pred), | ||||
| dtype: keras.backend.floatx() | dtype: keras.backend.floatx() | ||||
| @@ -74,6 +74,26 @@ public class MetricsTest : EagerModeTestBase | |||||
| Assert.AreEqual(r, 0.3f); | Assert.AreEqual(r, 0.3f); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalAccuracy | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void SparseCategoricalAccuracy() | |||||
| { | |||||
| var y_true = np.array(new[] { 2, 1 }); | |||||
| var y_pred = np.array(new[,] { { 0.1f, 0.6f, 0.3f }, { 0.05f, 0.95f, 0f } }); | |||||
| var m = tf.keras.metrics.SparseCategoricalAccuracy(); | |||||
| m.update_state(y_true, y_pred); | |||||
| var r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 0.5f); | |||||
| m.reset_states(); | |||||
| var weights = np.array(new[] { 0.7f, 0.3f }); | |||||
| m.update_state(y_true, y_pred, sample_weight: weights); | |||||
| r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 0.3f); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalCrossentropy | /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalCrossentropy | ||||
| /// </summary> | /// </summary> | ||||
| @@ -94,6 +114,20 @@ public class MetricsTest : EagerModeTestBase | |||||
| Assert.AreEqual(r, 1.6271976f); | Assert.AreEqual(r, 1.6271976f); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalCrossentropy | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void SparseCategoricalCrossentropy() | |||||
| { | |||||
| var y_true = np.array(new[] { 1, 2 }); | |||||
| var y_pred = np.array(new[,] { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } }); | |||||
| var m = tf.keras.metrics.SparseCategoricalCrossentropy(); | |||||
| m.update_state(y_true, y_pred); | |||||
| var r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 1.1769392f); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CosineSimilarity | /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CosineSimilarity | ||||
| /// </summary> | /// </summary> | ||||
| @@ -207,6 +241,26 @@ public class MetricsTest : EagerModeTestBase | |||||
| Assert.AreEqual(r, 0.3f); | Assert.AreEqual(r, 0.3f); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseTopKCategoricalAccuracy | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void SparseTopKCategoricalAccuracy() | |||||
| { | |||||
| var y_true = np.array(new[] { 2, 1 }); | |||||
| var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } }); | |||||
| var m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k: 1); | |||||
| m.update_state(y_true, y_pred); | |||||
| var r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 0.5f); | |||||
| m.reset_states(); | |||||
| var weights = np.array(new[] { 0.7f, 0.3f }); | |||||
| m.update_state(y_true, y_pred, sample_weight: weights); | |||||
| r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 0.3f); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy | /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy | ||||
| /// </summary> | /// </summary> | ||||