From 9d5bb8f1e23863b6e9998e405c3e1fa9d1c36d06 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 22 Dec 2019 09:02:56 -0600 Subject: [PATCH] tf.nn.ctc_greedy_decoder #473 --- src/TensorFlowNET.Core/APIs/tf.graph.cs | 2 +- src/TensorFlowNET.Core/APIs/tf.nn.cs | 3 + .../GraphTransformation/GraphTransformer.cs | 31 +++++++++ src/TensorFlowNET.Core/Operations/ctc_ops.cs | 67 +++++++++++++++++++ .../Operations/gen_ctc_ops.cs | 38 +++++++++++ .../{gen_image_ops.py.cs => gen_image_ops.cs} | 0 .../{gen_io_ops.py.cs => gen_io_ops.cs} | 0 src/TensorFlowNet.Benchmarks/Benchmark.csproj | 3 +- 8 files changed, 142 insertions(+), 2 deletions(-) create mode 100644 src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs create mode 100644 src/TensorFlowNET.Core/Operations/ctc_ops.cs create mode 100644 src/TensorFlowNET.Core/Operations/gen_ctc_ops.cs rename src/TensorFlowNET.Core/Operations/{gen_image_ops.py.cs => gen_image_ops.cs} (100%) rename src/TensorFlowNET.Core/Operations/{gen_io_ops.py.cs => gen_io_ops.cs} (100%) diff --git a/src/TensorFlowNET.Core/APIs/tf.graph.cs b/src/TensorFlowNET.Core/APIs/tf.graph.cs index a28c007a..05851b6b 100644 --- a/src/TensorFlowNET.Core/APIs/tf.graph.cs +++ b/src/TensorFlowNET.Core/APIs/tf.graph.cs @@ -21,7 +21,7 @@ namespace Tensorflow public partial class tensorflow { public graph_util_impl graph_util => new graph_util_impl(); - + public GraphTransformer graph_transforms => new GraphTransformer(); public GraphKeys GraphKeys { get; } = new GraphKeys(); public void reset_default_graph() diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index c416c6e8..2a6b125b 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -46,6 +46,9 @@ namespace Tensorflow return gen_nn_ops.conv2d(parameters); } + public Tensor[] ctc_greedy_decoder(Tensor inputs, Tensor sequence_length, bool merge_repeated = true, string name = null) + => gen_ctc_ops.ctc_greedy_decoder(inputs, sequence_length, merge_repeated: merge_repeated, name: name); + /// /// Computes dropout. /// diff --git a/src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs b/src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs new file mode 100644 index 00000000..5bc74351 --- /dev/null +++ b/src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class GraphTransformer + { + /// + /// Graph Transform Tool + /// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md + /// + /// GraphDef object containing a model to be transformed + /// the model inputs + /// the model outputs + /// transform names and parameters + /// + public GraphDef TransformGraph(GraphDef input_graph_def, + string[] inputs, + string[] outputs, + string[] transforms) + { + var input_graph_def_string = input_graph_def.ToString(); + var inputs_string = string.Join(",", inputs); + var outputs_string = string.Join(",", outputs); + var transforms_string = string.Join(",", transforms); + + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/ctc_ops.cs b/src/TensorFlowNET.Core/Operations/ctc_ops.cs new file mode 100644 index 00000000..07ed811d --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ctc_ops.cs @@ -0,0 +1,67 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class ctc_ops + { + /// + /// Performs greedy decoding on the logits given in inputs. + /// + /// + /// 3-D, shape: (max_time x batch_size x num_classes), the logits. + /// + /// + /// A vector containing sequence lengths, size (batch_size). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CTCGreedyDecoder'. + /// + /// + /// If True, merge repeated classes in output. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// decoded_indices : Indices matrix, size (total_decoded_outputs x 2), + /// of a SparseTensor<int64, 2>. The rows store: [batch, time]. + /// decoded_values : Values vector, size: (total_decoded_outputs), + /// of a SparseTensor<int64, 2>. The vector stores the decoded classes. + /// decoded_shape : Shape vector, size (2), of the decoded SparseTensor. + /// Values are: [batch_size, max_decoded_length]. + /// log_probability : Matrix, size (batch_size x 1), containing sequence + /// log-probabilities. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// A note about the attribute merge_repeated: if enabled, when + /// consecutive logits' maximum indices are the same, only the first of + /// these is emitted. Labeling the blank '*', the sequence "A B B * B B" + /// becomes "A B B" if merge_repeated = True and "A B B B B" if + /// merge_repeated = False. + /// + /// Regardless of the value of merge_repeated, if the maximum index of a given + /// time and batch corresponds to the blank, index (num_classes - 1), no new + /// element is emitted. + /// + public Tensor[] ctc_greedy_decoder(Tensor inputs, Tensor sequence_length, bool merge_repeated = true, string name = null) + => gen_ctc_ops.ctc_greedy_decoder(inputs, sequence_length, merge_repeated: merge_repeated, name: name); + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_ctc_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ctc_ops.cs new file mode 100644 index 00000000..018a56bb --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_ctc_ops.cs @@ -0,0 +1,38 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public class gen_ctc_ops + { + public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + + public static Tensor[] ctc_greedy_decoder(Tensor inputs, Tensor sequence_length, bool merge_repeated = true, string name = "CTCGreedyDecoder") + { + var op = _op_def_lib._apply_op_helper("CTCGreedyDecoder", name: name, args: new + { + inputs, + sequence_length, + merge_repeated + }); + /*var decoded_indices = op.outputs[0]; + var decoded_values = op.outputs[1]; + var decoded_shape = op.outputs[2]; + var log_probability = op.outputs[3];*/ + return op.outputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_image_ops.cs similarity index 100% rename from src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs rename to src/TensorFlowNET.Core/Operations/gen_image_ops.cs diff --git a/src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_io_ops.cs similarity index 100% rename from src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs rename to src/TensorFlowNET.Core/Operations/gen_io_ops.cs diff --git a/src/TensorFlowNet.Benchmarks/Benchmark.csproj b/src/TensorFlowNet.Benchmarks/Benchmark.csproj index ad4e17e6..ca5338bd 100644 --- a/src/TensorFlowNet.Benchmarks/Benchmark.csproj +++ b/src/TensorFlowNet.Benchmarks/Benchmark.csproj @@ -19,7 +19,8 @@ - + +