diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs
index 1867d637..930afa0b 100644
--- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs
+++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs
@@ -2,6 +2,7 @@
public interface IMetricFunc
{
+ string Name { get; }
///
/// Accumulates metric statistics.
///
diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
index e27c198d..75946303 100644
--- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
@@ -5,6 +5,10 @@ public interface IMetricsApi
Tensor binary_accuracy(Tensor y_true, Tensor y_pred);
Tensor categorical_accuracy(Tensor y_true, Tensor y_pred);
+ Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred,
+ bool from_logits = false,
+ float label_smoothing = 0f,
+ Axis? axis = null);
Tensor mean_absolute_error(Tensor y_true, Tensor y_pred);
@@ -27,14 +31,39 @@ public interface IMetricsApi
///
Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5);
+ ///
+ /// Calculates how often predictions match binary labels.
+ ///
+ ///
+ IMetricFunc BinaryAccuracy(string name = "binary_accuracy",
+ TF_DataType dtype = TF_DataType.TF_FLOAT,
+ float threshold = 05f);
+
+ ///
+ /// Calculates how often predictions match one-hot labels.
+ ///
+ ///
+ IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy",
+ TF_DataType dtype = TF_DataType.TF_FLOAT,
+ bool from_logits = false,
+ float label_smoothing = 0f,
+ Axis? axis = null);
+
+ ///
+ /// Computes the crossentropy metric between the labels and predictions.
+ ///
+ ///
+ IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy",
+ TF_DataType dtype = TF_DataType.TF_FLOAT);
+
///
/// Computes how often targets are in the top K predictions.
///
- ///
- ///
///
///
- IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT);
+ IMetricFunc TopKCategoricalAccuracy(int k = 5,
+ string name = "top_k_categorical_accuracy",
+ TF_DataType dtype = TF_DataType.TF_FLOAT);
///
/// Computes the precision of the predictions with respect to the labels.
@@ -45,7 +74,11 @@ public interface IMetricsApi
///
///
///
- IMetricFunc Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
+ IMetricFunc Precision(float thresholds = 0.5f,
+ int top_k = 0,
+ int class_id = 0,
+ string name = "recall",
+ TF_DataType dtype = TF_DataType.TF_FLOAT);
///
/// Computes the recall of the predictions with respect to the labels.
@@ -56,5 +89,9 @@ public interface IMetricsApi
///
///
///
- IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
+ IMetricFunc Recall(float thresholds = 0.5f,
+ int top_k = 0,
+ int class_id = 0,
+ string name = "recall",
+ TF_DataType dtype = TF_DataType.TF_FLOAT);
}
diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
index 5eb05eaa..ee638410 100644
--- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
+++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
@@ -9,15 +9,21 @@ namespace Tensorflow.Keras.Engine
{
public class MetricsContainer : Container
{
- string[] _user_metrics;
- string[] _metric_names;
- Metric[] _metrics;
- List _metrics_in_order;
+ IMetricFunc[] _user_metrics = new IMetricFunc[0];
+ string[] _metric_names = new string[0];
+ Metric[] _metrics = new Metric[0];
+ List _metrics_in_order = new List();
- public MetricsContainer(string[] metrics, string[] output_names = null)
+ public MetricsContainer(IMetricFunc[] metrics, string[] output_names = null)
: base(output_names)
{
_user_metrics = metrics;
+ _built = false;
+ }
+
+ public MetricsContainer(string[] metrics, string[] output_names = null)
+ : base(output_names)
+ {
_metric_names = metrics;
_built = false;
}
@@ -46,9 +52,11 @@ namespace Tensorflow.Keras.Engine
void _create_ordered_metrics()
{
- _metrics_in_order = new List();
foreach (var m in _metrics)
_metrics_in_order.append(m);
+
+ foreach(var m in _user_metrics)
+ _metrics_in_order.append(m);
}
Metric[] _get_metric_objects(string[] metrics, Tensor y_t, Tensor y_p)
@@ -56,7 +64,7 @@ namespace Tensorflow.Keras.Engine
return metrics.Select(x => _get_metric_object(x, y_t, y_p)).ToArray();
}
- Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p)
+ public Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p)
{
Func metric_obj = null;
if (metric == "accuracy" || metric == "acc")
@@ -94,7 +102,7 @@ namespace Tensorflow.Keras.Engine
return new MeanMetricWrapper(metric_obj, metric);
}
- public IEnumerable metrics
+ public IEnumerable metrics
{
get
{
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs
index 7b051f1d..3d99129b 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs
@@ -1,6 +1,7 @@
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Losses;
+using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Optimizers;
namespace Tensorflow.Keras.Engine
@@ -31,6 +32,27 @@ namespace Tensorflow.Keras.Engine
_is_compiled = true;
}
+ public void compile(OptimizerV2 optimizer = null,
+ ILossFunc loss = null,
+ IMetricFunc[] metrics = null)
+ {
+ this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs
+ {
+ });
+
+ this.loss = loss ?? new MeanSquaredError();
+
+ compiled_loss = new LossesContainer(loss, output_names: output_names);
+ compiled_metrics = new MetricsContainer(metrics, output_names: output_names);
+
+ int experimental_steps_per_execution = 1;
+ _configure_steps_per_execution(experimental_steps_per_execution);
+
+ // Initialize cache attrs.
+ _reset_compile_cache();
+ _is_compiled = true;
+ }
+
public void compile(string optimizer, string loss, string[] metrics)
{
var _optimizer = optimizer switch
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs b/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs
index 214b9934..0e33b14e 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs
@@ -5,11 +5,11 @@ namespace Tensorflow.Keras.Engine
{
public partial class Model
{
- public IEnumerable metrics
+ public IEnumerable metrics
{
get
{
- var _metrics = new List();
+ var _metrics = new List();
if (_is_compiled)
{
diff --git a/src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs
new file mode 100644
index 00000000..2977588e
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs
@@ -0,0 +1,11 @@
+namespace Tensorflow.Keras.Metrics;
+
+public class BinaryAccuracy : MeanMetricWrapper
+{
+ public BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 0.5f)
+ : base((yt, yp) => metrics_utils.binary_matches(yt, yp),
+ name: name,
+ dtype: dtype)
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs
new file mode 100644
index 00000000..d15cf26c
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs
@@ -0,0 +1,12 @@
+namespace Tensorflow.Keras.Metrics;
+
+public class CategoricalAccuracy : MeanMetricWrapper
+{
+ public CategoricalAccuracy(string name = "categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
+ : base((yt, yp) => metrics_utils.sparse_categorical_matches(
+ tf.math.argmax(yt, axis: -1), yp),
+ name: name,
+ dtype: dtype)
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs
new file mode 100644
index 00000000..95720c41
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs
@@ -0,0 +1,16 @@
+namespace Tensorflow.Keras.Metrics;
+
+public class CategoricalCrossentropy : MeanMetricWrapper
+{
+ public CategoricalCrossentropy(string name = "categorical_crossentropy",
+ TF_DataType dtype = TF_DataType.TF_FLOAT,
+ bool from_logits = false,
+ float label_smoothing = 0f,
+ Axis? axis = null)
+ : base((yt, yp) => keras.metrics.categorical_crossentropy(
+ yt, yp, from_logits: from_logits, label_smoothing: label_smoothing, axis: axis ?? -1),
+ name: name,
+ dtype: dtype)
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
index 74db1666..fcd0516b 100644
--- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
@@ -15,6 +15,18 @@
return math_ops.cast(eql, TF_DataType.TF_FLOAT);
}
+ public Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float label_smoothing = 0, Axis? axis = null)
+ {
+ y_true = tf.cast(y_true, y_pred.dtype);
+ // var label_smoothing_tensor = tf.convert_to_tensor(label_smoothing, dtype: y_pred.dtype);
+ if (label_smoothing > 0)
+ {
+ var num_classes = tf.cast(tf.shape(y_true)[-1], y_pred.dtype);
+ y_true = y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes);
+ }
+ return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis);
+ }
+
///
/// Calculates how often predictions matches integer labels.
///
@@ -59,6 +71,15 @@
);
}
+ public IMetricFunc BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 5)
+ => new BinaryAccuracy();
+
+ public IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
+ => new CategoricalAccuracy(name: name, dtype: dtype);
+
+ public IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy", TF_DataType dtype = TF_DataType.TF_FLOAT, bool from_logits = false, float label_smoothing = 0, Axis? axis = null)
+ => new CategoricalCrossentropy();
+
public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);
diff --git a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
index 0251462e..0f523e7e 100644
--- a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
+++ b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
@@ -1,10 +1,48 @@
using Tensorflow.Keras.Utils;
-using Tensorflow.NumPy;
namespace Tensorflow.Keras.Metrics;
public class metrics_utils
{
+ public static Tensor binary_matches(Tensor y_true, Tensor y_pred, float threshold = 0.5f)
+ {
+ y_pred = tf.cast(y_pred > threshold, y_pred.dtype);
+ return tf.cast(tf.equal(y_true, y_pred), keras.backend.floatx());
+ }
+
+ ///
+ /// Creates float Tensor, 1.0 for label-prediction match, 0.0 for mismatch.
+ ///
+ ///
+ ///
+ ///
+ public static Tensor sparse_categorical_matches(Tensor y_true, Tensor y_pred)
+ {
+ var reshape_matches = false;
+ var y_true_rank = y_true.shape.ndim;
+ var y_pred_rank = y_pred.shape.ndim;
+ var y_true_org_shape = tf.shape(y_true);
+
+ if (y_true_rank > -1 && y_pred_rank > -1 && y_true.ndim == y_pred.ndim )
+ {
+ reshape_matches = true;
+ y_true = tf.squeeze(y_true, new Shape(-1));
+ }
+ y_pred = tf.math.argmax(y_pred, axis: -1);
+
+ var matches = tf.cast(
+ tf.equal(y_true, y_pred),
+ dtype: keras.backend.floatx()
+ );
+
+ if (reshape_matches)
+ {
+ return tf.reshape(matches, shape: y_true_org_shape);
+ }
+
+ return matches;
+ }
+
public static Tensor sparse_top_k_categorical_matches(Tensor y_true, Tensor y_pred, int k = 5)
{
var reshape_matches = false;
diff --git a/src/TensorFlowNET.Keras/Utils/losses_utils.cs b/src/TensorFlowNET.Keras/Utils/losses_utils.cs
index 717acf5e..9ba40ca0 100644
--- a/src/TensorFlowNET.Keras/Utils/losses_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/losses_utils.cs
@@ -75,10 +75,8 @@ namespace Tensorflow.Keras.Utils
{
sample_weight = tf.expand_dims(sample_weight, -1);
}
- else
- {
- return (y_pred, y_true, sample_weight);
- }
+
+ return (y_pred, y_true, sample_weight);
}
throw new NotImplementedException("");
diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
index f3ba2e93..9389af96 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
@@ -14,6 +14,66 @@ namespace TensorFlowNET.Keras.UnitTest;
[TestClass]
public class MetricsTest : EagerModeTestBase
{
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/BinaryAccuracy
+ ///
+ [TestMethod]
+ public void BinaryAccuracy()
+ {
+ var y_true = np.array(new[,] { { 1 }, { 1 },{ 0 }, { 0 } });
+ var y_pred = np.array(new[,] { { 0.98f }, { 1f }, { 0f }, { 0.6f } });
+ var m = tf.keras.metrics.BinaryAccuracy();
+ /*m.update_state(y_true, y_pred);
+ var r = m.result().numpy();
+ Assert.AreEqual(r, 0.75f);
+
+ m.reset_states();*/
+ var weights = np.array(new[] { 1f, 0f, 0f, 1f });
+ m.update_state(y_true, y_pred, sample_weight: weights);
+ var r = m.result().numpy();
+ Assert.AreEqual(r, 0.5f);
+ }
+
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalAccuracy
+ ///
+ [TestMethod]
+ public void CategoricalAccuracy()
+ {
+ var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } });
+ var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });
+ var m = tf.keras.metrics.CategoricalAccuracy();
+ m.update_state(y_true, y_pred);
+ var r = m.result().numpy();
+ Assert.AreEqual(r, 0.5f);
+
+ m.reset_states();
+ var weights = np.array(new[] { 0.7f, 0.3f });
+ m.update_state(y_true, y_pred, sample_weight: weights);
+ r = m.result().numpy();
+ Assert.AreEqual(r, 0.3f);
+ }
+
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalCrossentropy
+ ///
+ [TestMethod]
+ public void CategoricalCrossentropy()
+ {
+ var y_true = np.array(new[,] { { 0, 1, 0 }, { 0, 0, 1 } });
+ var y_pred = np.array(new[,] { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } });
+ var m = tf.keras.metrics.CategoricalCrossentropy();
+ m.update_state(y_true, y_pred);
+ var r = m.result().numpy();
+ Assert.AreEqual(r, 1.1769392f);
+
+ m.reset_states();
+ var weights = np.array(new[] { 0.3f, 0.7f });
+ m.update_state(y_true, y_pred, sample_weight: weights);
+ r = m.result().numpy();
+ Assert.AreEqual(r, 1.6271976f);
+ }
+
///
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
///