Browse Source

argsort fix.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
135562e5bf
2 changed files with 21 additions and 2 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/NumPy/Axis.cs
  2. +15
    -2
      src/TensorFlowNET.Core/Operations/sort_ops.cs

+ 6
- 0
src/TensorFlowNET.Core/NumPy/Axis.cs View File

@@ -55,6 +55,12 @@ namespace Tensorflow
public static implicit operator Tensor(Axis axis)
=> constant_op.constant(axis);

public static bool operator ==(Axis left, int right)
=> left.IsScalar && left[0] == right;

public static bool operator !=(Axis left, int right)
=> !(left == right);

public override string ToString()
=> IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})";
}


+ 15
- 2
src/TensorFlowNET.Core/Operations/sort_ops.cs View File

@@ -14,7 +14,7 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Operations;
using System;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -26,8 +26,21 @@ namespace Tensorflow
axis = axis ?? new Axis(-1);
var k = array_ops.shape(values)[axis];
values = -values;
var static_rank = values.shape.ndim;
var top_k_input = values;
if (axis == -1 || axis + 1 == values.shape.ndim)
{
}
else
{
if (axis == 0 && static_rank == 2)
top_k_input = array_ops.transpose(values, new[] { 1, 0 });
else
throw new NotImplementedException("");
}

var (_, indices) = tf.Context.ExecuteOp("TopKV2", name,
new ExecuteOpArgs(values, k).SetAttributes(new
new ExecuteOpArgs(top_k_input, k).SetAttributes(new
{
sorted = true
}));


Loading…
Cancel
Save