Browse Source

np.sort

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
56a64dae2e
3 changed files with 56 additions and 3 deletions
  1. +20
    -2
      src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs
  2. +25
    -0
      src/TensorFlowNET.Core/Operations/sort_ops.cs
  3. +11
    -1
      test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs

+ 20
- 2
src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Globalization;
using System.Numerics;
using System.Text;

@@ -9,11 +10,11 @@ namespace Tensorflow.NumPy
public partial class np
{
[AutoNumPy]
public static NDArray argmax(NDArray a, Axis axis = null)
public static NDArray argmax(NDArray a, Axis? axis = null)
=> new NDArray(math_ops.argmax(a, axis ?? 0));

[AutoNumPy]
public static NDArray argsort(NDArray a, Axis axis = null)
public static NDArray argsort(NDArray a, Axis? axis = null)
=> new NDArray(sort_ops.argsort(a, axis: axis ?? -1));

[AutoNumPy]
@@ -25,5 +26,22 @@ namespace Tensorflow.NumPy

[AutoNumPy]
public static void shuffle(NDArray x) => np.random.shuffle(x);

/// <summary>
/// Sorts a ndarray
/// </summary>
/// <param name="values"></param>
/// <param name="axis">
/// The axis along which to sort. The default is -1, which sorts the last axis.
/// </param>
/// <param name="direction">
/// The direction in which to sort the values (`'ASCENDING'` or `'DESCENDING'`)
/// </param>
/// <returns>
/// A `NDArray` with the same dtype and shape as `values`, with the elements sorted along the given `axis`.
/// </returns>
[AutoNumPy]
public static NDArray sort(NDArray values, Axis? axis = null, string direction = "ASCENDING")
=> new NDArray(sort_ops.sort(values, axis: axis ?? -1, direction: direction));
}
}

+ 25
- 0
src/TensorFlowNET.Core/Operations/sort_ops.cs View File

@@ -47,6 +47,31 @@ namespace Tensorflow
return indices;
}

public static Tensor sort(Tensor values, Axis axis, string direction = "ASCENDING", string? name = null)
{
var k = array_ops.shape(values)[axis];
values = -values;
var static_rank = values.shape.ndim;
var top_k_input = values;
if (axis == -1 || axis + 1 == values.shape.ndim)
{
}
else
{
if (axis == 0 && static_rank == 2)
top_k_input = array_ops.transpose(values, new[] { 1, 0 });
else
throw new NotImplementedException("");
}

(values, _) = tf.Context.ExecuteOp("TopKV2", name,
new ExecuteOpArgs(top_k_input, k).SetAttributes(new
{
sorted = true
}));
return -values;
}

public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null)
=> tf.Context.ExecuteOp("MatrixInverse", name,
new ExecuteOpArgs(input).SetAttributes(new


+ 11
- 1
test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs View File

@@ -5,7 +5,6 @@ using System.Linq;
using System.Text;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NumPy
{
@@ -30,5 +29,16 @@ namespace TensorFlowNET.UnitTest.NumPy
Assert.AreEqual(ind[0], new[] { 0, 1 });
Assert.AreEqual(ind[1], new[] { 1, 0 });
}

/// <summary>
/// https://numpy.org/doc/stable/reference/generated/numpy.sort.html
/// </summary>
[TestMethod]
public void sort()
{
var x = np.array(new int[] { 3, 1, 2 });
var sorted = np.sort(x);
Assert.IsTrue(sorted.ToArray<int>() is [1, 2, 3]);
}
}
}

Loading…
Cancel
Save