diff --git a/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs b/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs index 61feb5e7..5182d572 100644 --- a/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs +++ b/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs @@ -1,6 +1,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Globalization; using System.Numerics; using System.Text; @@ -9,11 +10,11 @@ namespace Tensorflow.NumPy public partial class np { [AutoNumPy] - public static NDArray argmax(NDArray a, Axis axis = null) + public static NDArray argmax(NDArray a, Axis? axis = null) => new NDArray(math_ops.argmax(a, axis ?? 0)); [AutoNumPy] - public static NDArray argsort(NDArray a, Axis axis = null) + public static NDArray argsort(NDArray a, Axis? axis = null) => new NDArray(sort_ops.argsort(a, axis: axis ?? -1)); [AutoNumPy] @@ -25,5 +26,22 @@ namespace Tensorflow.NumPy [AutoNumPy] public static void shuffle(NDArray x) => np.random.shuffle(x); + + /// + /// Sorts a ndarray + /// + /// + /// + /// The axis along which to sort. The default is -1, which sorts the last axis. + /// + /// + /// The direction in which to sort the values (`'ASCENDING'` or `'DESCENDING'`) + /// + /// + /// A `NDArray` with the same dtype and shape as `values`, with the elements sorted along the given `axis`. + /// + [AutoNumPy] + public static NDArray sort(NDArray values, Axis? axis = null, string direction = "ASCENDING") + => new NDArray(sort_ops.sort(values, axis: axis ?? -1, direction: direction)); } } diff --git a/src/TensorFlowNET.Core/Operations/sort_ops.cs b/src/TensorFlowNET.Core/Operations/sort_ops.cs index 1dcaf1f8..34b90323 100644 --- a/src/TensorFlowNET.Core/Operations/sort_ops.cs +++ b/src/TensorFlowNET.Core/Operations/sort_ops.cs @@ -47,6 +47,31 @@ namespace Tensorflow return indices; } + public static Tensor sort(Tensor values, Axis axis, string direction = "ASCENDING", string? name = null) + { + var k = array_ops.shape(values)[axis]; + values = -values; + var static_rank = values.shape.ndim; + var top_k_input = values; + if (axis == -1 || axis + 1 == values.shape.ndim) + { + } + else + { + if (axis == 0 && static_rank == 2) + top_k_input = array_ops.transpose(values, new[] { 1, 0 }); + else + throw new NotImplementedException(""); + } + + (values, _) = tf.Context.ExecuteOp("TopKV2", name, + new ExecuteOpArgs(top_k_input, k).SetAttributes(new + { + sorted = true + })); + return -values; + } + public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null) => tf.Context.ExecuteOp("MatrixInverse", name, new ExecuteOpArgs(input).SetAttributes(new diff --git a/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs index 2a617d40..13a5d973 100644 --- a/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs +++ b/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs @@ -5,7 +5,6 @@ using System.Linq; using System.Text; using Tensorflow; using Tensorflow.NumPy; -using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.NumPy { @@ -30,5 +29,16 @@ namespace TensorFlowNET.UnitTest.NumPy Assert.AreEqual(ind[0], new[] { 0, 1 }); Assert.AreEqual(ind[1], new[] { 1, 0 }); } + + /// + /// https://numpy.org/doc/stable/reference/generated/numpy.sort.html + /// + [TestMethod] + public void sort() + { + var x = np.array(new int[] { 3, 1, 2 }); + var sorted = np.sort(x); + Assert.IsTrue(sorted.ToArray() is [1, 2, 3]); + } } }