diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index ce6dc4d6..7d3f6eff 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -39,6 +39,9 @@ namespace Tensorflow
public Tensor sum(Tensor x, Axis? axis = null, string name = null)
=> math_ops.reduce_sum(x, axis: axis, name: name);
+ public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK")
+ => nn_ops.in_top_k(predictions, targets, k, name);
+
///
///
///
diff --git a/src/TensorFlowNET.Core/Keras/IKerasApi.cs b/src/TensorFlowNET.Core/Keras/IKerasApi.cs
index 49ec9a5f..cffd3f79 100644
--- a/src/TensorFlowNET.Core/Keras/IKerasApi.cs
+++ b/src/TensorFlowNET.Core/Keras/IKerasApi.cs
@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses;
+using Tensorflow.Keras.Metrics;
namespace Tensorflow.Keras
{
@@ -10,6 +11,7 @@ namespace Tensorflow.Keras
{
public ILayersApi layers { get; }
public ILossesApi losses { get; }
+ public IMetricsApi metrics { get; }
public IInitializersApi initializers { get; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
new file mode 100644
index 00000000..2fe6d809
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
@@ -0,0 +1,29 @@
+namespace Tensorflow.Keras.Metrics;
+
+public interface IMetricsApi
+{
+ Tensor binary_accuracy(Tensor y_true, Tensor y_pred);
+
+ Tensor categorical_accuracy(Tensor y_true, Tensor y_pred);
+
+ Tensor mean_absolute_error(Tensor y_true, Tensor y_pred);
+
+ Tensor mean_absolute_percentage_error(Tensor y_true, Tensor y_pred);
+
+ ///
+ /// Calculates how often predictions matches integer labels.
+ ///
+ /// Integer ground truth values.
+ /// The prediction values.
+ /// Sparse categorical accuracy values.
+ Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred);
+
+ ///
+ /// Computes how often targets are in the top `K` predictions.
+ ///
+ ///
+ ///
+ ///
+ ///
+ Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5);
+}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index 0567858f..408d06eb 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -240,16 +240,8 @@ namespace Tensorflow.Operations
///
/// A `Tensor` of type `bool`.
public static Tensor in_top_kv2(Tensor predictions, Tensor targets, int k, string name = null)
- {
- var _op = tf.OpDefLib._apply_op_helper("InTopKV2", name: name, args: new
- {
- predictions,
- targets,
- k
- });
-
- return _op.output;
- }
+ => tf.Context.ExecuteOp("InTopKV2", name,
+ new ExecuteOpArgs(predictions, targets, k));
public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null)
=> tf.Context.ExecuteOp("LeakyRelu", name,
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index 5f09f202..7af89f13 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -121,6 +121,11 @@ namespace Tensorflow
if (dtype == TF_DataType.TF_INT32)
values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray();
}
+ else if (values is double[] double_values)
+ {
+ if (dtype == TF_DataType.TF_FLOAT)
+ values = double_values.Select(x => (float)Convert.ChangeType(x, new_system_dtype)).ToArray();
+ }
else
values = Convert.ChangeType(values, new_system_dtype);
diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs
index 4e0c612b..e0d148ce 100644
--- a/src/TensorFlowNET.Keras/KerasInterface.cs
+++ b/src/TensorFlowNET.Keras/KerasInterface.cs
@@ -27,7 +27,7 @@ namespace Tensorflow.Keras
ThreadLocal _backend = new ThreadLocal(() => new BackendImpl());
public BackendImpl backend => _backend.Value;
public OptimizerApi optimizers { get; } = new OptimizerApi();
- public MetricsApi metrics { get; } = new MetricsApi();
+ public IMetricsApi metrics { get; } = new MetricsApi();
public ModelsApi models { get; } = new ModelsApi();
public KerasUtils utils { get; } = new KerasUtils();
diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
index 3d614e02..6b0e2d8a 100644
--- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
@@ -2,7 +2,7 @@
namespace Tensorflow.Keras.Metrics
{
- public class MetricsApi
+ public class MetricsApi : IMetricsApi
{
public Tensor binary_accuracy(Tensor y_true, Tensor y_pred)
{
@@ -53,5 +53,12 @@ namespace Tensorflow.Keras.Metrics
var diff = (y_true - y_pred) / math_ops.maximum(math_ops.abs(y_true), keras.backend.epsilon());
return 100f * keras.backend.mean(math_ops.abs(diff), axis: -1);
}
+
+ public Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5)
+ {
+ return metrics_utils.sparse_top_k_categorical_matches(
+ tf.math.argmax(y_true, axis: -1), y_pred, k
+ );
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
new file mode 100644
index 00000000..de6a8402
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
@@ -0,0 +1,39 @@
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Keras.Metrics;
+
+public class metrics_utils
+{
+ public static Tensor sparse_top_k_categorical_matches(Tensor y_true, Tensor y_pred, int k = 5)
+ {
+ var reshape_matches = false;
+ var y_true_rank = y_true.shape.ndim;
+ var y_pred_rank = y_pred.shape.ndim;
+ var y_true_org_shape = tf.shape(y_true);
+
+ if (y_pred_rank > 2)
+ {
+ y_pred = tf.reshape(y_pred, (-1, y_pred.shape[-1]));
+ }
+
+ if (y_true_rank > 1)
+ {
+ reshape_matches = true;
+ y_true = tf.reshape(y_true, new Shape(-1));
+ }
+
+ var matches = tf.cast(
+ tf.math.in_top_k(
+ predictions: y_pred, targets: tf.cast(y_true, np.int32), k: k
+ ),
+ dtype: keras.backend.floatx()
+ );
+
+ if (reshape_matches)
+ {
+ return tf.reshape(matches, shape: y_true_org_shape);
+ }
+
+ return matches;
+ }
+}
diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
new file mode 100644
index 00000000..bb0107d4
--- /dev/null
+++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
@@ -0,0 +1,28 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Tensorflow;
+using Tensorflow.NumPy;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.Keras.UnitTest;
+
+[TestClass]
+public class MetricsTest : EagerModeTestBase
+{
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy
+ ///
+ [TestMethod]
+ public void top_k_categorical_accuracy()
+ {
+ var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } });
+ var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });
+ var m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k: 3);
+ Assert.AreEqual(m.numpy(), new[] { 1f, 1f });
+ }
+}