diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs index 05644640..6c7189df 100644 --- a/src/TensorFlowNET.Core/NumPy/Axis.cs +++ b/src/TensorFlowNET.Core/NumPy/Axis.cs @@ -55,6 +55,12 @@ namespace Tensorflow public static implicit operator Tensor(Axis axis) => constant_op.constant(axis); + public static bool operator ==(Axis left, int right) + => left.IsScalar && left[0] == right; + + public static bool operator !=(Axis left, int right) + => !(left == right); + public override string ToString() => IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})"; } diff --git a/src/TensorFlowNET.Core/Operations/sort_ops.cs b/src/TensorFlowNET.Core/Operations/sort_ops.cs index 314daefd..1dcaf1f8 100644 --- a/src/TensorFlowNET.Core/Operations/sort_ops.cs +++ b/src/TensorFlowNET.Core/Operations/sort_ops.cs @@ -14,7 +14,7 @@ limitations under the License. ******************************************************************************/ -using Tensorflow.Operations; +using System; using static Tensorflow.Binding; namespace Tensorflow @@ -26,8 +26,21 @@ namespace Tensorflow axis = axis ?? new Axis(-1); var k = array_ops.shape(values)[axis]; values = -values; + var static_rank = values.shape.ndim; + var top_k_input = values; + if (axis == -1 || axis + 1 == values.shape.ndim) + { + } + else + { + if (axis == 0 && static_rank == 2) + top_k_input = array_ops.transpose(values, new[] { 1, 0 }); + else + throw new NotImplementedException(""); + } + var (_, indices) = tf.Context.ExecuteOp("TopKV2", name, - new ExecuteOpArgs(values, k).SetAttributes(new + new ExecuteOpArgs(top_k_input, k).SetAttributes(new { sorted = true }));