Browse Source

inputs for rnn/while/TensorArrayReadV3 are incorrect #433

tags/v0.12
Oceania2018 6 years ago
parent
commit
d1e1e05546
1 changed files with 26 additions and 7 deletions
  1. +26
    -7
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs

+ 26
- 7
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
@@ -24,7 +25,7 @@ namespace Tensorflow.Operations
{
internal class rnn
{
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor,
public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor,
Tensor sequence_length = null, Tensor initial_state = null,
TF_DataType dtype = TF_DataType.DtInvalid,
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
@@ -79,7 +80,7 @@ namespace Tensorflow.Operations
/// <param name="sequence_length"></param>
/// <param name="dtype"></param>
/// <returns></returns>
private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, Tensor initial_state,
private static (Tensor, Tensor) _dynamic_rnn_loop(RnnCell cell, Tensor inputs, Tensor initial_state,
int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
var state = initial_state;
@@ -170,11 +171,11 @@ namespace Tensorflow.Operations
flat_input_i.dtype));
}

for (int i = 0; i < input_ta.Count; i++)
input_ta = zip(input_ta, flat_input).Select(x =>
{
var (ta, input_) = (input_ta[i], flat_input[i]);
ta.unstack(input_);
}
var (ta, input_) = (x.Item1, x.Item2);
return ta.unstack(input_);
}).ToList();
}

// Make sure that we run at least 1 step, if necessary, to ensure
@@ -192,11 +193,29 @@ namespace Tensorflow.Operations
// Take a time step of the dynamic RNN.
Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) =>
{
Tensor[] input_t = null;
var (time1, output_ta_t, state1) = (item.time, item.output_ta_t, item.state);
if (in_graph_mode)
{
input_ta.Select(ta => ta.read(time)).ToArray();
input_t = input_ta.Select(ta => ta.read(time1)).ToArray();
// Restore some shape information
foreach (var (input_, shape) in zip(input_t, inputs_got_shape))
input_.set_shape(shape[new Slice(1)]);
}
else
{
// input_t = tuple(ta[time.numpy()] for ta in input_ta)
}

var input_t_t = nest.pack_sequence_as2(structure: inputs, flat_sequence: input_t);
// Keras RNN cells only accept state as list, even if it's a single tensor.
// var is_keras_rnn_cell = _is_keras_rnn_cell(cell);
(Tensor, Tensor) a = (null, null);
if (sequence_length != null)
throw new NotImplementedException("sequence_length != null");
else
a = cell.__call__(input_t_t, state1);

return item;
};



Loading…
Cancel
Save