diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs index 1867d637..930afa0b 100644 --- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs +++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs @@ -2,6 +2,7 @@ public interface IMetricFunc { + string Name { get; } /// /// Accumulates metric statistics. /// diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs index e27c198d..75946303 100644 --- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs +++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs @@ -5,6 +5,10 @@ public interface IMetricsApi Tensor binary_accuracy(Tensor y_true, Tensor y_pred); Tensor categorical_accuracy(Tensor y_true, Tensor y_pred); + Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred, + bool from_logits = false, + float label_smoothing = 0f, + Axis? axis = null); Tensor mean_absolute_error(Tensor y_true, Tensor y_pred); @@ -27,14 +31,39 @@ public interface IMetricsApi /// Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5); + /// + /// Calculates how often predictions match binary labels. + /// + /// + IMetricFunc BinaryAccuracy(string name = "binary_accuracy", + TF_DataType dtype = TF_DataType.TF_FLOAT, + float threshold = 05f); + + /// + /// Calculates how often predictions match one-hot labels. + /// + /// + IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy", + TF_DataType dtype = TF_DataType.TF_FLOAT, + bool from_logits = false, + float label_smoothing = 0f, + Axis? axis = null); + + /// + /// Computes the crossentropy metric between the labels and predictions. + /// + /// + IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy", + TF_DataType dtype = TF_DataType.TF_FLOAT); + /// /// Computes how often targets are in the top K predictions. /// - /// - /// /// /// - IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT); + IMetricFunc TopKCategoricalAccuracy(int k = 5, + string name = "top_k_categorical_accuracy", + TF_DataType dtype = TF_DataType.TF_FLOAT); /// /// Computes the precision of the predictions with respect to the labels. @@ -45,7 +74,11 @@ public interface IMetricsApi /// /// /// - IMetricFunc Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT); + IMetricFunc Precision(float thresholds = 0.5f, + int top_k = 0, + int class_id = 0, + string name = "recall", + TF_DataType dtype = TF_DataType.TF_FLOAT); /// /// Computes the recall of the predictions with respect to the labels. @@ -56,5 +89,9 @@ public interface IMetricsApi /// /// /// - IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT); + IMetricFunc Recall(float thresholds = 0.5f, + int top_k = 0, + int class_id = 0, + string name = "recall", + TF_DataType dtype = TF_DataType.TF_FLOAT); } diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs index 5eb05eaa..ee638410 100644 --- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs +++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs @@ -9,15 +9,21 @@ namespace Tensorflow.Keras.Engine { public class MetricsContainer : Container { - string[] _user_metrics; - string[] _metric_names; - Metric[] _metrics; - List _metrics_in_order; + IMetricFunc[] _user_metrics = new IMetricFunc[0]; + string[] _metric_names = new string[0]; + Metric[] _metrics = new Metric[0]; + List _metrics_in_order = new List(); - public MetricsContainer(string[] metrics, string[] output_names = null) + public MetricsContainer(IMetricFunc[] metrics, string[] output_names = null) : base(output_names) { _user_metrics = metrics; + _built = false; + } + + public MetricsContainer(string[] metrics, string[] output_names = null) + : base(output_names) + { _metric_names = metrics; _built = false; } @@ -46,9 +52,11 @@ namespace Tensorflow.Keras.Engine void _create_ordered_metrics() { - _metrics_in_order = new List(); foreach (var m in _metrics) _metrics_in_order.append(m); + + foreach(var m in _user_metrics) + _metrics_in_order.append(m); } Metric[] _get_metric_objects(string[] metrics, Tensor y_t, Tensor y_p) @@ -56,7 +64,7 @@ namespace Tensorflow.Keras.Engine return metrics.Select(x => _get_metric_object(x, y_t, y_p)).ToArray(); } - Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p) + public Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p) { Func metric_obj = null; if (metric == "accuracy" || metric == "acc") @@ -94,7 +102,7 @@ namespace Tensorflow.Keras.Engine return new MeanMetricWrapper(metric_obj, metric); } - public IEnumerable metrics + public IEnumerable metrics { get { diff --git a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs index 7b051f1d..3d99129b 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs @@ -1,6 +1,7 @@ using System; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Optimizers; namespace Tensorflow.Keras.Engine @@ -31,6 +32,27 @@ namespace Tensorflow.Keras.Engine _is_compiled = true; } + public void compile(OptimizerV2 optimizer = null, + ILossFunc loss = null, + IMetricFunc[] metrics = null) + { + this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs + { + }); + + this.loss = loss ?? new MeanSquaredError(); + + compiled_loss = new LossesContainer(loss, output_names: output_names); + compiled_metrics = new MetricsContainer(metrics, output_names: output_names); + + int experimental_steps_per_execution = 1; + _configure_steps_per_execution(experimental_steps_per_execution); + + // Initialize cache attrs. + _reset_compile_cache(); + _is_compiled = true; + } + public void compile(string optimizer, string loss, string[] metrics) { var _optimizer = optimizer switch diff --git a/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs b/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs index 214b9934..0e33b14e 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs @@ -5,11 +5,11 @@ namespace Tensorflow.Keras.Engine { public partial class Model { - public IEnumerable metrics + public IEnumerable metrics { get { - var _metrics = new List(); + var _metrics = new List(); if (_is_compiled) { diff --git a/src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs new file mode 100644 index 00000000..2977588e --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs @@ -0,0 +1,11 @@ +namespace Tensorflow.Keras.Metrics; + +public class BinaryAccuracy : MeanMetricWrapper +{ + public BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 0.5f) + : base((yt, yp) => metrics_utils.binary_matches(yt, yp), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs new file mode 100644 index 00000000..d15cf26c --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs @@ -0,0 +1,12 @@ +namespace Tensorflow.Keras.Metrics; + +public class CategoricalAccuracy : MeanMetricWrapper +{ + public CategoricalAccuracy(string name = "categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + : base((yt, yp) => metrics_utils.sparse_categorical_matches( + tf.math.argmax(yt, axis: -1), yp), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs new file mode 100644 index 00000000..95720c41 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs @@ -0,0 +1,16 @@ +namespace Tensorflow.Keras.Metrics; + +public class CategoricalCrossentropy : MeanMetricWrapper +{ + public CategoricalCrossentropy(string name = "categorical_crossentropy", + TF_DataType dtype = TF_DataType.TF_FLOAT, + bool from_logits = false, + float label_smoothing = 0f, + Axis? axis = null) + : base((yt, yp) => keras.metrics.categorical_crossentropy( + yt, yp, from_logits: from_logits, label_smoothing: label_smoothing, axis: axis ?? -1), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs index 74db1666..fcd0516b 100644 --- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs +++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs @@ -15,6 +15,18 @@ return math_ops.cast(eql, TF_DataType.TF_FLOAT); } + public Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float label_smoothing = 0, Axis? axis = null) + { + y_true = tf.cast(y_true, y_pred.dtype); + // var label_smoothing_tensor = tf.convert_to_tensor(label_smoothing, dtype: y_pred.dtype); + if (label_smoothing > 0) + { + var num_classes = tf.cast(tf.shape(y_true)[-1], y_pred.dtype); + y_true = y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes); + } + return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis); + } + /// /// Calculates how often predictions matches integer labels. /// @@ -59,6 +71,15 @@ ); } + public IMetricFunc BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 5) + => new BinaryAccuracy(); + + public IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new CategoricalAccuracy(name: name, dtype: dtype); + + public IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy", TF_DataType dtype = TF_DataType.TF_FLOAT, bool from_logits = false, float label_smoothing = 0, Axis? axis = null) + => new CategoricalCrossentropy(); + public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); diff --git a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs index 0251462e..0f523e7e 100644 --- a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs +++ b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs @@ -1,10 +1,48 @@ using Tensorflow.Keras.Utils; -using Tensorflow.NumPy; namespace Tensorflow.Keras.Metrics; public class metrics_utils { + public static Tensor binary_matches(Tensor y_true, Tensor y_pred, float threshold = 0.5f) + { + y_pred = tf.cast(y_pred > threshold, y_pred.dtype); + return tf.cast(tf.equal(y_true, y_pred), keras.backend.floatx()); + } + + /// + /// Creates float Tensor, 1.0 for label-prediction match, 0.0 for mismatch. + /// + /// + /// + /// + public static Tensor sparse_categorical_matches(Tensor y_true, Tensor y_pred) + { + var reshape_matches = false; + var y_true_rank = y_true.shape.ndim; + var y_pred_rank = y_pred.shape.ndim; + var y_true_org_shape = tf.shape(y_true); + + if (y_true_rank > -1 && y_pred_rank > -1 && y_true.ndim == y_pred.ndim ) + { + reshape_matches = true; + y_true = tf.squeeze(y_true, new Shape(-1)); + } + y_pred = tf.math.argmax(y_pred, axis: -1); + + var matches = tf.cast( + tf.equal(y_true, y_pred), + dtype: keras.backend.floatx() + ); + + if (reshape_matches) + { + return tf.reshape(matches, shape: y_true_org_shape); + } + + return matches; + } + public static Tensor sparse_top_k_categorical_matches(Tensor y_true, Tensor y_pred, int k = 5) { var reshape_matches = false; diff --git a/src/TensorFlowNET.Keras/Utils/losses_utils.cs b/src/TensorFlowNET.Keras/Utils/losses_utils.cs index 717acf5e..9ba40ca0 100644 --- a/src/TensorFlowNET.Keras/Utils/losses_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/losses_utils.cs @@ -75,10 +75,8 @@ namespace Tensorflow.Keras.Utils { sample_weight = tf.expand_dims(sample_weight, -1); } - else - { - return (y_pred, y_true, sample_weight); - } + + return (y_pred, y_true, sample_weight); } throw new NotImplementedException(""); diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs index f3ba2e93..9389af96 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs @@ -14,6 +14,66 @@ namespace TensorFlowNET.Keras.UnitTest; [TestClass] public class MetricsTest : EagerModeTestBase { + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/BinaryAccuracy + /// + [TestMethod] + public void BinaryAccuracy() + { + var y_true = np.array(new[,] { { 1 }, { 1 },{ 0 }, { 0 } }); + var y_pred = np.array(new[,] { { 0.98f }, { 1f }, { 0f }, { 0.6f } }); + var m = tf.keras.metrics.BinaryAccuracy(); + /*m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.75f); + + m.reset_states();*/ + var weights = np.array(new[] { 1f, 0f, 0f, 1f }); + m.update_state(y_true, y_pred, sample_weight: weights); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.5f); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalAccuracy + /// + [TestMethod] + public void CategoricalAccuracy() + { + var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } }); + var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } }); + var m = tf.keras.metrics.CategoricalAccuracy(); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.5f); + + m.reset_states(); + var weights = np.array(new[] { 0.7f, 0.3f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 0.3f); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalCrossentropy + /// + [TestMethod] + public void CategoricalCrossentropy() + { + var y_true = np.array(new[,] { { 0, 1, 0 }, { 0, 0, 1 } }); + var y_pred = np.array(new[,] { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } }); + var m = tf.keras.metrics.CategoricalCrossentropy(); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 1.1769392f); + + m.reset_states(); + var weights = np.array(new[] { 0.3f, 0.7f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 1.6271976f); + } + /// /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy ///