| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections; | using System.Collections; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Globalization; | |||||
| using System.Numerics; | using System.Numerics; | ||||
| using System.Text; | using System.Text; | ||||
| @@ -9,11 +10,11 @@ namespace Tensorflow.NumPy | |||||
| public partial class np | public partial class np | ||||
| { | { | ||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray argmax(NDArray a, Axis axis = null) | |||||
| public static NDArray argmax(NDArray a, Axis? axis = null) | |||||
| => new NDArray(math_ops.argmax(a, axis ?? 0)); | => new NDArray(math_ops.argmax(a, axis ?? 0)); | ||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray argsort(NDArray a, Axis axis = null) | |||||
| public static NDArray argsort(NDArray a, Axis? axis = null) | |||||
| => new NDArray(sort_ops.argsort(a, axis: axis ?? -1)); | => new NDArray(sort_ops.argsort(a, axis: axis ?? -1)); | ||||
| [AutoNumPy] | [AutoNumPy] | ||||
| @@ -25,5 +26,22 @@ namespace Tensorflow.NumPy | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static void shuffle(NDArray x) => np.random.shuffle(x); | public static void shuffle(NDArray x) => np.random.shuffle(x); | ||||
| /// <summary> | |||||
| /// Sorts a ndarray | |||||
| /// </summary> | |||||
| /// <param name="values"></param> | |||||
| /// <param name="axis"> | |||||
| /// The axis along which to sort. The default is -1, which sorts the last axis. | |||||
| /// </param> | |||||
| /// <param name="direction"> | |||||
| /// The direction in which to sort the values (`'ASCENDING'` or `'DESCENDING'`) | |||||
| /// </param> | |||||
| /// <returns> | |||||
| /// A `NDArray` with the same dtype and shape as `values`, with the elements sorted along the given `axis`. | |||||
| /// </returns> | |||||
| [AutoNumPy] | |||||
| public static NDArray sort(NDArray values, Axis? axis = null, string direction = "ASCENDING") | |||||
| => new NDArray(sort_ops.sort(values, axis: axis ?? -1, direction: direction)); | |||||
| } | } | ||||
| } | } | ||||
| @@ -47,6 +47,31 @@ namespace Tensorflow | |||||
| return indices; | return indices; | ||||
| } | } | ||||
| public static Tensor sort(Tensor values, Axis axis, string direction = "ASCENDING", string? name = null) | |||||
| { | |||||
| var k = array_ops.shape(values)[axis]; | |||||
| values = -values; | |||||
| var static_rank = values.shape.ndim; | |||||
| var top_k_input = values; | |||||
| if (axis == -1 || axis + 1 == values.shape.ndim) | |||||
| { | |||||
| } | |||||
| else | |||||
| { | |||||
| if (axis == 0 && static_rank == 2) | |||||
| top_k_input = array_ops.transpose(values, new[] { 1, 0 }); | |||||
| else | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| (values, _) = tf.Context.ExecuteOp("TopKV2", name, | |||||
| new ExecuteOpArgs(top_k_input, k).SetAttributes(new | |||||
| { | |||||
| sorted = true | |||||
| })); | |||||
| return -values; | |||||
| } | |||||
| public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null) | public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null) | ||||
| => tf.Context.ExecuteOp("MatrixInverse", name, | => tf.Context.ExecuteOp("MatrixInverse", name, | ||||
| new ExecuteOpArgs(input).SetAttributes(new | new ExecuteOpArgs(input).SetAttributes(new | ||||
| @@ -5,7 +5,6 @@ using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.NumPy | namespace TensorFlowNET.UnitTest.NumPy | ||||
| { | { | ||||
| @@ -30,5 +29,16 @@ namespace TensorFlowNET.UnitTest.NumPy | |||||
| Assert.AreEqual(ind[0], new[] { 0, 1 }); | Assert.AreEqual(ind[0], new[] { 0, 1 }); | ||||
| Assert.AreEqual(ind[1], new[] { 1, 0 }); | Assert.AreEqual(ind[1], new[] { 1, 0 }); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// https://numpy.org/doc/stable/reference/generated/numpy.sort.html | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void sort() | |||||
| { | |||||
| var x = np.array(new int[] { 3, 1, 2 }); | |||||
| var sorted = np.sort(x); | |||||
| Assert.IsTrue(sorted.ToArray<int>() is [1, 2, 3]); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||