From d1e1e05546f883f245be3c869cad383753aac790 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 27 Oct 2019 10:46:49 -0500 Subject: [PATCH] inputs for rnn/while/TensorArrayReadV3 are incorrect #433 --- .../Operations/NnOps/rnn.cs | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 41516bb8..41a4622a 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using NumSharp; using System; using System.Collections.Generic; using System.Linq; @@ -24,7 +25,7 @@ namespace Tensorflow.Operations { internal class rnn { - public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor, + public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor, Tensor sequence_length = null, Tensor initial_state = null, TF_DataType dtype = TF_DataType.DtInvalid, int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) @@ -79,7 +80,7 @@ namespace Tensorflow.Operations /// /// /// - private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, Tensor initial_state, + private static (Tensor, Tensor) _dynamic_rnn_loop(RnnCell cell, Tensor inputs, Tensor initial_state, int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid) { var state = initial_state; @@ -170,11 +171,11 @@ namespace Tensorflow.Operations flat_input_i.dtype)); } - for (int i = 0; i < input_ta.Count; i++) + input_ta = zip(input_ta, flat_input).Select(x => { - var (ta, input_) = (input_ta[i], flat_input[i]); - ta.unstack(input_); - } + var (ta, input_) = (x.Item1, x.Item2); + return ta.unstack(input_); + }).ToList(); } // Make sure that we run at least 1 step, if necessary, to ensure @@ -192,11 +193,29 @@ namespace Tensorflow.Operations // Take a time step of the dynamic RNN. Func _time_step = (item) => { + Tensor[] input_t = null; + var (time1, output_ta_t, state1) = (item.time, item.output_ta_t, item.state); if (in_graph_mode) { - input_ta.Select(ta => ta.read(time)).ToArray(); + input_t = input_ta.Select(ta => ta.read(time1)).ToArray(); + // Restore some shape information + foreach (var (input_, shape) in zip(input_t, inputs_got_shape)) + input_.set_shape(shape[new Slice(1)]); + } + else + { + // input_t = tuple(ta[time.numpy()] for ta in input_ta) } + var input_t_t = nest.pack_sequence_as2(structure: inputs, flat_sequence: input_t); + // Keras RNN cells only accept state as list, even if it's a single tensor. + // var is_keras_rnn_cell = _is_keras_rnn_cell(cell); + (Tensor, Tensor) a = (null, null); + if (sequence_length != null) + throw new NotImplementedException("sequence_length != null"); + else + a = cell.__call__(input_t_t, state1); + return item; };