diff --git a/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs b/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs
index 61feb5e7..5182d572 100644
--- a/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs
+++ b/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
+using System.Globalization;
using System.Numerics;
using System.Text;
@@ -9,11 +10,11 @@ namespace Tensorflow.NumPy
public partial class np
{
[AutoNumPy]
- public static NDArray argmax(NDArray a, Axis axis = null)
+ public static NDArray argmax(NDArray a, Axis? axis = null)
=> new NDArray(math_ops.argmax(a, axis ?? 0));
[AutoNumPy]
- public static NDArray argsort(NDArray a, Axis axis = null)
+ public static NDArray argsort(NDArray a, Axis? axis = null)
=> new NDArray(sort_ops.argsort(a, axis: axis ?? -1));
[AutoNumPy]
@@ -25,5 +26,22 @@ namespace Tensorflow.NumPy
[AutoNumPy]
public static void shuffle(NDArray x) => np.random.shuffle(x);
+
+ ///
+ /// Sorts a ndarray
+ ///
+ ///
+ ///
+ /// The axis along which to sort. The default is -1, which sorts the last axis.
+ ///
+ ///
+ /// The direction in which to sort the values (`'ASCENDING'` or `'DESCENDING'`)
+ ///
+ ///
+ /// A `NDArray` with the same dtype and shape as `values`, with the elements sorted along the given `axis`.
+ ///
+ [AutoNumPy]
+ public static NDArray sort(NDArray values, Axis? axis = null, string direction = "ASCENDING")
+ => new NDArray(sort_ops.sort(values, axis: axis ?? -1, direction: direction));
}
}
diff --git a/src/TensorFlowNET.Core/Operations/sort_ops.cs b/src/TensorFlowNET.Core/Operations/sort_ops.cs
index 1dcaf1f8..34b90323 100644
--- a/src/TensorFlowNET.Core/Operations/sort_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/sort_ops.cs
@@ -47,6 +47,31 @@ namespace Tensorflow
return indices;
}
+ public static Tensor sort(Tensor values, Axis axis, string direction = "ASCENDING", string? name = null)
+ {
+ var k = array_ops.shape(values)[axis];
+ values = -values;
+ var static_rank = values.shape.ndim;
+ var top_k_input = values;
+ if (axis == -1 || axis + 1 == values.shape.ndim)
+ {
+ }
+ else
+ {
+ if (axis == 0 && static_rank == 2)
+ top_k_input = array_ops.transpose(values, new[] { 1, 0 });
+ else
+ throw new NotImplementedException("");
+ }
+
+ (values, _) = tf.Context.ExecuteOp("TopKV2", name,
+ new ExecuteOpArgs(top_k_input, k).SetAttributes(new
+ {
+ sorted = true
+ }));
+ return -values;
+ }
+
public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null)
=> tf.Context.ExecuteOp("MatrixInverse", name,
new ExecuteOpArgs(input).SetAttributes(new
diff --git a/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs
index 2a617d40..13a5d973 100644
--- a/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs
+++ b/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs
@@ -5,7 +5,6 @@ using System.Linq;
using System.Text;
using Tensorflow;
using Tensorflow.NumPy;
-using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.NumPy
{
@@ -30,5 +29,16 @@ namespace TensorFlowNET.UnitTest.NumPy
Assert.AreEqual(ind[0], new[] { 0, 1 });
Assert.AreEqual(ind[1], new[] { 1, 0 });
}
+
+ ///
+ /// https://numpy.org/doc/stable/reference/generated/numpy.sort.html
+ ///
+ [TestMethod]
+ public void sort()
+ {
+ var x = np.array(new int[] { 3, 1, 2 });
+ var sorted = np.sort(x);
+ Assert.IsTrue(sorted.ToArray() is [1, 2, 3]);
+ }
}
}