| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| @@ -24,7 +25,7 @@ namespace Tensorflow.Operations | |||
| { | |||
| internal class rnn | |||
| { | |||
| public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor, | |||
| public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor, | |||
| Tensor sequence_length = null, Tensor initial_state = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) | |||
| @@ -79,7 +80,7 @@ namespace Tensorflow.Operations | |||
| /// <param name="sequence_length"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <returns></returns> | |||
| private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, Tensor initial_state, | |||
| private static (Tensor, Tensor) _dynamic_rnn_loop(RnnCell cell, Tensor inputs, Tensor initial_state, | |||
| int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| var state = initial_state; | |||
| @@ -170,11 +171,11 @@ namespace Tensorflow.Operations | |||
| flat_input_i.dtype)); | |||
| } | |||
| for (int i = 0; i < input_ta.Count; i++) | |||
| input_ta = zip(input_ta, flat_input).Select(x => | |||
| { | |||
| var (ta, input_) = (input_ta[i], flat_input[i]); | |||
| ta.unstack(input_); | |||
| } | |||
| var (ta, input_) = (x.Item1, x.Item2); | |||
| return ta.unstack(input_); | |||
| }).ToList(); | |||
| } | |||
| // Make sure that we run at least 1 step, if necessary, to ensure | |||
| @@ -192,11 +193,29 @@ namespace Tensorflow.Operations | |||
| // Take a time step of the dynamic RNN. | |||
| Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) => | |||
| { | |||
| Tensor[] input_t = null; | |||
| var (time1, output_ta_t, state1) = (item.time, item.output_ta_t, item.state); | |||
| if (in_graph_mode) | |||
| { | |||
| input_ta.Select(ta => ta.read(time)).ToArray(); | |||
| input_t = input_ta.Select(ta => ta.read(time1)).ToArray(); | |||
| // Restore some shape information | |||
| foreach (var (input_, shape) in zip(input_t, inputs_got_shape)) | |||
| input_.set_shape(shape[new Slice(1)]); | |||
| } | |||
| else | |||
| { | |||
| // input_t = tuple(ta[time.numpy()] for ta in input_ta) | |||
| } | |||
| var input_t_t = nest.pack_sequence_as2(structure: inputs, flat_sequence: input_t); | |||
| // Keras RNN cells only accept state as list, even if it's a single tensor. | |||
| // var is_keras_rnn_cell = _is_keras_rnn_cell(cell); | |||
| (Tensor, Tensor) a = (null, null); | |||
| if (sequence_length != null) | |||
| throw new NotImplementedException("sequence_length != null"); | |||
| else | |||
| a = cell.__call__(input_t_t, state1); | |||
| return item; | |||
| }; | |||