|
|
|
@@ -14,7 +14,7 @@ |
|
|
|
limitations under the License. |
|
|
|
******************************************************************************/ |
|
|
|
|
|
|
|
using Tensorflow.Operations; |
|
|
|
using System; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
|
namespace Tensorflow |
|
|
|
@@ -26,8 +26,21 @@ namespace Tensorflow |
|
|
|
axis = axis ?? new Axis(-1); |
|
|
|
var k = array_ops.shape(values)[axis]; |
|
|
|
values = -values; |
|
|
|
var static_rank = values.shape.ndim; |
|
|
|
var top_k_input = values; |
|
|
|
if (axis == -1 || axis + 1 == values.shape.ndim) |
|
|
|
{ |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
if (axis == 0 && static_rank == 2) |
|
|
|
top_k_input = array_ops.transpose(values, new[] { 1, 0 }); |
|
|
|
else |
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
|
|
|
|
var (_, indices) = tf.Context.ExecuteOp("TopKV2", name, |
|
|
|
new ExecuteOpArgs(values, k).SetAttributes(new |
|
|
|
new ExecuteOpArgs(top_k_input, k).SetAttributes(new |
|
|
|
{ |
|
|
|
sorted = true |
|
|
|
})); |
|
|
|
|