| @@ -2,6 +2,7 @@ | |||||
| public interface IMetricFunc | public interface IMetricFunc | ||||
| { | { | ||||
| string Name { get; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Accumulates metric statistics. | /// Accumulates metric statistics. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -5,6 +5,10 @@ public interface IMetricsApi | |||||
| Tensor binary_accuracy(Tensor y_true, Tensor y_pred); | Tensor binary_accuracy(Tensor y_true, Tensor y_pred); | ||||
| Tensor categorical_accuracy(Tensor y_true, Tensor y_pred); | Tensor categorical_accuracy(Tensor y_true, Tensor y_pred); | ||||
| Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred, | |||||
| bool from_logits = false, | |||||
| float label_smoothing = 0f, | |||||
| Axis? axis = null); | |||||
| Tensor mean_absolute_error(Tensor y_true, Tensor y_pred); | Tensor mean_absolute_error(Tensor y_true, Tensor y_pred); | ||||
| @@ -27,14 +31,39 @@ public interface IMetricsApi | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5); | Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5); | ||||
| /// <summary> | |||||
| /// Calculates how often predictions match binary labels. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| IMetricFunc BinaryAccuracy(string name = "binary_accuracy", | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| float threshold = 05f); | |||||
| /// <summary> | |||||
| /// Calculates how often predictions match one-hot labels. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy", | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| bool from_logits = false, | |||||
| float label_smoothing = 0f, | |||||
| Axis? axis = null); | |||||
| /// <summary> | |||||
| /// Computes the crossentropy metric between the labels and predictions. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy", | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
| /// <summary> | /// <summary> | ||||
| /// Computes how often targets are in the top K predictions. | /// Computes how often targets are in the top K predictions. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="y_true"></param> | |||||
| /// <param name="y_pred"></param> | |||||
| /// <param name="k"></param> | /// <param name="k"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
| IMetricFunc TopKCategoricalAccuracy(int k = 5, | |||||
| string name = "top_k_categorical_accuracy", | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
| /// <summary> | /// <summary> | ||||
| /// Computes the precision of the predictions with respect to the labels. | /// Computes the precision of the predictions with respect to the labels. | ||||
| @@ -45,7 +74,11 @@ public interface IMetricsApi | |||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <param name="dtype"></param> | /// <param name="dtype"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| IMetricFunc Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
| IMetricFunc Precision(float thresholds = 0.5f, | |||||
| int top_k = 0, | |||||
| int class_id = 0, | |||||
| string name = "recall", | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
| /// <summary> | /// <summary> | ||||
| /// Computes the recall of the predictions with respect to the labels. | /// Computes the recall of the predictions with respect to the labels. | ||||
| @@ -56,5 +89,9 @@ public interface IMetricsApi | |||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <param name="dtype"></param> | /// <param name="dtype"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
| IMetricFunc Recall(float thresholds = 0.5f, | |||||
| int top_k = 0, | |||||
| int class_id = 0, | |||||
| string name = "recall", | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
| } | } | ||||
| @@ -9,15 +9,21 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| public class MetricsContainer : Container | public class MetricsContainer : Container | ||||
| { | { | ||||
| string[] _user_metrics; | |||||
| string[] _metric_names; | |||||
| Metric[] _metrics; | |||||
| List<Metric> _metrics_in_order; | |||||
| IMetricFunc[] _user_metrics = new IMetricFunc[0]; | |||||
| string[] _metric_names = new string[0]; | |||||
| Metric[] _metrics = new Metric[0]; | |||||
| List<IMetricFunc> _metrics_in_order = new List<IMetricFunc>(); | |||||
| public MetricsContainer(string[] metrics, string[] output_names = null) | |||||
| public MetricsContainer(IMetricFunc[] metrics, string[] output_names = null) | |||||
| : base(output_names) | : base(output_names) | ||||
| { | { | ||||
| _user_metrics = metrics; | _user_metrics = metrics; | ||||
| _built = false; | |||||
| } | |||||
| public MetricsContainer(string[] metrics, string[] output_names = null) | |||||
| : base(output_names) | |||||
| { | |||||
| _metric_names = metrics; | _metric_names = metrics; | ||||
| _built = false; | _built = false; | ||||
| } | } | ||||
| @@ -46,9 +52,11 @@ namespace Tensorflow.Keras.Engine | |||||
| void _create_ordered_metrics() | void _create_ordered_metrics() | ||||
| { | { | ||||
| _metrics_in_order = new List<Metric>(); | |||||
| foreach (var m in _metrics) | foreach (var m in _metrics) | ||||
| _metrics_in_order.append(m); | _metrics_in_order.append(m); | ||||
| foreach(var m in _user_metrics) | |||||
| _metrics_in_order.append(m); | |||||
| } | } | ||||
| Metric[] _get_metric_objects(string[] metrics, Tensor y_t, Tensor y_p) | Metric[] _get_metric_objects(string[] metrics, Tensor y_t, Tensor y_p) | ||||
| @@ -56,7 +64,7 @@ namespace Tensorflow.Keras.Engine | |||||
| return metrics.Select(x => _get_metric_object(x, y_t, y_p)).ToArray(); | return metrics.Select(x => _get_metric_object(x, y_t, y_p)).ToArray(); | ||||
| } | } | ||||
| Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p) | |||||
| public Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p) | |||||
| { | { | ||||
| Func<Tensor, Tensor, Tensor> metric_obj = null; | Func<Tensor, Tensor, Tensor> metric_obj = null; | ||||
| if (metric == "accuracy" || metric == "acc") | if (metric == "accuracy" || metric == "acc") | ||||
| @@ -94,7 +102,7 @@ namespace Tensorflow.Keras.Engine | |||||
| return new MeanMetricWrapper(metric_obj, metric); | return new MeanMetricWrapper(metric_obj, metric); | ||||
| } | } | ||||
| public IEnumerable<Metric> metrics | |||||
| public IEnumerable<IMetricFunc> metrics | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
| using Tensorflow.Keras.Metrics; | |||||
| using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| @@ -31,6 +32,27 @@ namespace Tensorflow.Keras.Engine | |||||
| _is_compiled = true; | _is_compiled = true; | ||||
| } | } | ||||
| public void compile(OptimizerV2 optimizer = null, | |||||
| ILossFunc loss = null, | |||||
| IMetricFunc[] metrics = null) | |||||
| { | |||||
| this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs | |||||
| { | |||||
| }); | |||||
| this.loss = loss ?? new MeanSquaredError(); | |||||
| compiled_loss = new LossesContainer(loss, output_names: output_names); | |||||
| compiled_metrics = new MetricsContainer(metrics, output_names: output_names); | |||||
| int experimental_steps_per_execution = 1; | |||||
| _configure_steps_per_execution(experimental_steps_per_execution); | |||||
| // Initialize cache attrs. | |||||
| _reset_compile_cache(); | |||||
| _is_compiled = true; | |||||
| } | |||||
| public void compile(string optimizer, string loss, string[] metrics) | public void compile(string optimizer, string loss, string[] metrics) | ||||
| { | { | ||||
| var _optimizer = optimizer switch | var _optimizer = optimizer switch | ||||
| @@ -5,11 +5,11 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| public partial class Model | public partial class Model | ||||
| { | { | ||||
| public IEnumerable<Metric> metrics | |||||
| public IEnumerable<IMetricFunc> metrics | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| var _metrics = new List<Metric>(); | |||||
| var _metrics = new List<IMetricFunc>(); | |||||
| if (_is_compiled) | if (_is_compiled) | ||||
| { | { | ||||
| @@ -0,0 +1,11 @@ | |||||
| namespace Tensorflow.Keras.Metrics; | |||||
| public class BinaryAccuracy : MeanMetricWrapper | |||||
| { | |||||
| public BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 0.5f) | |||||
| : base((yt, yp) => metrics_utils.binary_matches(yt, yp), | |||||
| name: name, | |||||
| dtype: dtype) | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,12 @@ | |||||
| namespace Tensorflow.Keras.Metrics; | |||||
| public class CategoricalAccuracy : MeanMetricWrapper | |||||
| { | |||||
| public CategoricalAccuracy(string name = "categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| : base((yt, yp) => metrics_utils.sparse_categorical_matches( | |||||
| tf.math.argmax(yt, axis: -1), yp), | |||||
| name: name, | |||||
| dtype: dtype) | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,16 @@ | |||||
| namespace Tensorflow.Keras.Metrics; | |||||
| public class CategoricalCrossentropy : MeanMetricWrapper | |||||
| { | |||||
| public CategoricalCrossentropy(string name = "categorical_crossentropy", | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| bool from_logits = false, | |||||
| float label_smoothing = 0f, | |||||
| Axis? axis = null) | |||||
| : base((yt, yp) => keras.metrics.categorical_crossentropy( | |||||
| yt, yp, from_logits: from_logits, label_smoothing: label_smoothing, axis: axis ?? -1), | |||||
| name: name, | |||||
| dtype: dtype) | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -15,6 +15,18 @@ | |||||
| return math_ops.cast(eql, TF_DataType.TF_FLOAT); | return math_ops.cast(eql, TF_DataType.TF_FLOAT); | ||||
| } | } | ||||
| public Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float label_smoothing = 0, Axis? axis = null) | |||||
| { | |||||
| y_true = tf.cast(y_true, y_pred.dtype); | |||||
| // var label_smoothing_tensor = tf.convert_to_tensor(label_smoothing, dtype: y_pred.dtype); | |||||
| if (label_smoothing > 0) | |||||
| { | |||||
| var num_classes = tf.cast(tf.shape(y_true)[-1], y_pred.dtype); | |||||
| y_true = y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes); | |||||
| } | |||||
| return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Calculates how often predictions matches integer labels. | /// Calculates how often predictions matches integer labels. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -59,6 +71,15 @@ | |||||
| ); | ); | ||||
| } | } | ||||
| public IMetricFunc BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 5) | |||||
| => new BinaryAccuracy(); | |||||
| public IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| => new CategoricalAccuracy(name: name, dtype: dtype); | |||||
| public IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy", TF_DataType dtype = TF_DataType.TF_FLOAT, bool from_logits = false, float label_smoothing = 0, Axis? axis = null) | |||||
| => new CategoricalCrossentropy(); | |||||
| public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | ||||
| => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); | => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); | ||||
| @@ -1,10 +1,48 @@ | |||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using Tensorflow.NumPy; | |||||
| namespace Tensorflow.Keras.Metrics; | namespace Tensorflow.Keras.Metrics; | ||||
| public class metrics_utils | public class metrics_utils | ||||
| { | { | ||||
| public static Tensor binary_matches(Tensor y_true, Tensor y_pred, float threshold = 0.5f) | |||||
| { | |||||
| y_pred = tf.cast(y_pred > threshold, y_pred.dtype); | |||||
| return tf.cast(tf.equal(y_true, y_pred), keras.backend.floatx()); | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates float Tensor, 1.0 for label-prediction match, 0.0 for mismatch. | |||||
| /// </summary> | |||||
| /// <param name="y_true"></param> | |||||
| /// <param name="y_pred"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor sparse_categorical_matches(Tensor y_true, Tensor y_pred) | |||||
| { | |||||
| var reshape_matches = false; | |||||
| var y_true_rank = y_true.shape.ndim; | |||||
| var y_pred_rank = y_pred.shape.ndim; | |||||
| var y_true_org_shape = tf.shape(y_true); | |||||
| if (y_true_rank > -1 && y_pred_rank > -1 && y_true.ndim == y_pred.ndim ) | |||||
| { | |||||
| reshape_matches = true; | |||||
| y_true = tf.squeeze(y_true, new Shape(-1)); | |||||
| } | |||||
| y_pred = tf.math.argmax(y_pred, axis: -1); | |||||
| var matches = tf.cast( | |||||
| tf.equal(y_true, y_pred), | |||||
| dtype: keras.backend.floatx() | |||||
| ); | |||||
| if (reshape_matches) | |||||
| { | |||||
| return tf.reshape(matches, shape: y_true_org_shape); | |||||
| } | |||||
| return matches; | |||||
| } | |||||
| public static Tensor sparse_top_k_categorical_matches(Tensor y_true, Tensor y_pred, int k = 5) | public static Tensor sparse_top_k_categorical_matches(Tensor y_true, Tensor y_pred, int k = 5) | ||||
| { | { | ||||
| var reshape_matches = false; | var reshape_matches = false; | ||||
| @@ -75,10 +75,8 @@ namespace Tensorflow.Keras.Utils | |||||
| { | { | ||||
| sample_weight = tf.expand_dims(sample_weight, -1); | sample_weight = tf.expand_dims(sample_weight, -1); | ||||
| } | } | ||||
| else | |||||
| { | |||||
| return (y_pred, y_true, sample_weight); | |||||
| } | |||||
| return (y_pred, y_true, sample_weight); | |||||
| } | } | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| @@ -14,6 +14,66 @@ namespace TensorFlowNET.Keras.UnitTest; | |||||
| [TestClass] | [TestClass] | ||||
| public class MetricsTest : EagerModeTestBase | public class MetricsTest : EagerModeTestBase | ||||
| { | { | ||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/BinaryAccuracy | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void BinaryAccuracy() | |||||
| { | |||||
| var y_true = np.array(new[,] { { 1 }, { 1 },{ 0 }, { 0 } }); | |||||
| var y_pred = np.array(new[,] { { 0.98f }, { 1f }, { 0f }, { 0.6f } }); | |||||
| var m = tf.keras.metrics.BinaryAccuracy(); | |||||
| /*m.update_state(y_true, y_pred); | |||||
| var r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 0.75f); | |||||
| m.reset_states();*/ | |||||
| var weights = np.array(new[] { 1f, 0f, 0f, 1f }); | |||||
| m.update_state(y_true, y_pred, sample_weight: weights); | |||||
| var r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 0.5f); | |||||
| } | |||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalAccuracy | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void CategoricalAccuracy() | |||||
| { | |||||
| var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } }); | |||||
| var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } }); | |||||
| var m = tf.keras.metrics.CategoricalAccuracy(); | |||||
| m.update_state(y_true, y_pred); | |||||
| var r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 0.5f); | |||||
| m.reset_states(); | |||||
| var weights = np.array(new[] { 0.7f, 0.3f }); | |||||
| m.update_state(y_true, y_pred, sample_weight: weights); | |||||
| r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 0.3f); | |||||
| } | |||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalCrossentropy | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void CategoricalCrossentropy() | |||||
| { | |||||
| var y_true = np.array(new[,] { { 0, 1, 0 }, { 0, 0, 1 } }); | |||||
| var y_pred = np.array(new[,] { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } }); | |||||
| var m = tf.keras.metrics.CategoricalCrossentropy(); | |||||
| m.update_state(y_true, y_pred); | |||||
| var r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 1.1769392f); | |||||
| m.reset_states(); | |||||
| var weights = np.array(new[] { 0.3f, 0.7f }); | |||||
| m.update_state(y_true, y_pred, sample_weight: weights); | |||||
| r = m.result().numpy(); | |||||
| Assert.AreEqual(r, 1.6271976f); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy | /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy | ||||
| /// </summary> | /// </summary> | ||||