From fcd2cd6573ee1608c092ed9092769c67f0bb19b1 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 3 Nov 2019 10:41:30 -0600 Subject: [PATCH] nn_ops.in_top_kv2 --- src/TensorFlowNET.Core/APIs/tf.nn.cs | 2 +- .../Operations/NnOps/gen_nn_ops.cs | 22 ++++++++++++++++++- src/TensorFlowNET.Core/Operations/nn_ops.cs | 8 +++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 5b5786d1..64d47acd 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -134,7 +134,7 @@ namespace Tensorflow => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name); public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK") - => gen_ops.in_top_k(predictions, targets, k, name); + => nn_ops.in_top_k(predictions, targets, k, name); public Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null) => gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name); diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index f3a63d68..fbc68dbf 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -244,7 +244,27 @@ namespace Tensorflow.Operations logits }); - return _op.outputs[0]; + return _op.output; + } + + /// + /// Says whether the targets are in the top `K` predictions. + /// + /// + /// + /// + /// + /// A `Tensor` of type `bool`. + public static Tensor in_top_kv2(Tensor predictions, Tensor targets, int k, string name = null) + { + var _op = _op_def_lib._apply_op_helper("InTopKV2", name: name, args: new + { + predictions, + targets, + k + }); + + return _op.output; } public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index 7ae1f3a9..124fd72b 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -111,6 +111,14 @@ namespace Tensorflow return noise_shape; } + public static Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = null) + { + return tf_with(ops.name_scope(name, "in_top_k"), delegate + { + return gen_nn_ops.in_top_kv2(predictions, targets, k, name: name); + }); + } + public static Tensor log_softmax(Tensor logits, int axis = -1, string name = null) { return _softmax(logits, gen_nn_ops.log_softmax, axis, name);