From 46b86e845f82318f2a2684c80d485f663e7c7b72 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Sun, 21 May 2023 01:37:08 +0800 Subject: [PATCH] Draft PR for RNN --- .../APIs/tf.control_flow.cs | 15 + .../Keras/ArgsDefinition/Rnn/RNNArgs.cs | 12 +- .../ArgsDefinition/Rnn/StackedRNNCellsArgs.cs | 3 +- .../NumPy/StateSizeWrapper.cs | 63 ++ .../Operations/NnOps/RNNCell.cs | 3 +- .../Operations/control_flow_ops.cs | 47 ++ src/TensorFlowNET.Core/Util/nest.py.cs | 44 ++ src/TensorFlowNET.Keras/BackendImpl.cs | 540 +++++++++++++++++- src/TensorFlowNET.Keras/Engine/Layer.cs | 4 +- .../Layers/Activation/ELU.cs | 2 +- .../Layers/Activation/Exponential.cs | 2 +- .../Layers/Activation/HardSigmoid.cs | 3 +- .../Layers/Activation/LeakyReLu.cs | 2 +- .../Layers/Activation/SELU.cs | 3 +- .../Layers/Activation/Softmax.cs | 3 +- .../Layers/Activation/Softplus.cs | 3 +- .../Layers/Activation/Softsign.cs | 3 +- .../Layers/Activation/Swish.cs | 3 +- .../Layers/Activation/Tanh.cs | 2 +- .../Layers/Attention/BaseDenseAttention.cs | 2 +- .../Layers/Attention/MultiHeadAttention.cs | 2 +- .../Layers/Convolution/Convolutional.cs | 2 +- src/TensorFlowNET.Keras/Layers/Core/Dense.cs | 2 +- .../Layers/Core/EinsumDense.cs | 2 +- .../Layers/Core/Embedding.cs | 2 +- .../Layers/Merging/Merge.cs | 2 +- .../Normalization/BatchNormalization.cs | 2 +- .../Normalization/LayerNormalization.cs | 2 +- .../Layers/Normalization/Normalization.cs | 2 +- .../Layers/Pooling/GlobalAveragePooling1D.cs | 2 +- .../Layers/Pooling/GlobalAveragePooling2D.cs | 2 +- .../Layers/Pooling/GlobalMaxPooling1D.cs | 2 +- .../Layers/Pooling/GlobalMaxPooling2D.cs | 2 +- .../Layers/Pooling/Pooling1D.cs | 2 +- .../Layers/Pooling/Pooling2D.cs | 2 +- .../Layers/Preprocessing/CategoryEncoding.cs | 2 +- .../Layers/Preprocessing/Rescaling.cs | 2 +- .../Layers/Preprocessing/Resizing.cs | 2 +- .../Layers/Regularization/Dropout.cs | 2 +- .../Layers/Reshaping/Cropping1D.cs | 2 +- .../Layers/Reshaping/Cropping2D.cs | 2 +- .../Layers/Reshaping/Cropping3D.cs | 2 +- .../Layers/Reshaping/Flatten.cs | 2 +- .../Layers/Reshaping/Permute.cs | 2 +- .../Layers/Reshaping/Reshape.cs | 2 +- .../Layers/Reshaping/UpSampling2D.cs | 2 +- .../Layers/Reshaping/ZeroPadding2D.cs | 2 +- src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs | 4 +- src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 422 +++++++++++++- .../Layers/Rnn/SimpleRNNCell.cs | 4 +- .../Layers/Rnn/StackedRNNCells.cs | 12 +- .../Layers/TensorFlowOpLayer.cs | 2 +- 52 files changed, 1187 insertions(+), 70 deletions(-) create mode 100644 src/TensorFlowNET.Core/NumPy/StateSizeWrapper.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs index 239487e0..578f23f9 100644 --- a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs +++ b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs @@ -57,6 +57,21 @@ namespace Tensorflow new[] { loop_vars }); return results[0]; } + public (Tensor, List, Tensors, Tensors) while_loop(Func cond, + Func, Tensors, Tensors, (Tensor, List, Tensors, Tensors)> body, + (Tensor, List, Tensors, Tensors) loop_vars, + int parallel_iterations = 10) + => control_flow_ops.while_loop(cond, + body, + loop_vars); + + public (Tensor, List, Tensors) while_loop(Func cond, + Func, Tensors, (Tensor, List, Tensors)> body, + (Tensor, List, Tensors) loop_vars, + int parallel_iterations = 10) + => control_flow_ops.while_loop(cond, + body, + loop_vars); public Tensor[] while_loop(Func cond, Func body, diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs index 2585592c..911c6721 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs @@ -1,5 +1,9 @@ using Newtonsoft.Json; +using OneOf; using System.Collections.Generic; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.ArgsDefinition.Rnn; +using Tensorflow.NumPy; namespace Tensorflow.Keras.ArgsDefinition.Rnn { @@ -7,11 +11,14 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn { public interface IRnnArgCell : ILayer { - object state_size { get; } + public Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null); + public StateSizeWrapper state_size { get; set; } + public int output_size { get; set; } } [JsonProperty("cell")] // TODO: the cell should be serialized with `serialize_keras_object`. - public IRnnArgCell Cell { get; set; } = null; + public OneOf, IRnnArgCell> Cell { get; set; } + [JsonProperty("return_sequences")] public bool ReturnSequences { get; set; } = false; [JsonProperty("return_state")] @@ -25,6 +32,7 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn [JsonProperty("time_major")] public bool TimeMajor { get; set; } = false; // TODO: Add `num_constants` and `zero_output_for_mask`. + public bool ZeroOutputForMask { get; set; } = false; public Dictionary Kwargs { get; set; } = null; public int Units { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs index fdfadab8..dee7e8d3 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs @@ -1,10 +1,11 @@ using System.Collections.Generic; +using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs; namespace Tensorflow.Keras.ArgsDefinition.Rnn { public class StackedRNNCellsArgs : LayerArgs { - public IList Cells { get; set; } + public IList Cells { get; set; } public Dictionary Kwargs { get; set; } = null; } } diff --git a/src/TensorFlowNET.Core/NumPy/StateSizeWrapper.cs b/src/TensorFlowNET.Core/NumPy/StateSizeWrapper.cs new file mode 100644 index 00000000..f2a9e5f2 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/StateSizeWrapper.cs @@ -0,0 +1,63 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Collections; + + +namespace Tensorflow.NumPy +{ + // Since state_size in RNN is a single integer or array of integer, so use StateSizeWrapper to hold it + public class StateSizeWrapper : IEnumerable + { + int[] _state_size; + public int[] state_size => _state_size; + + public StateSizeWrapper(int state_size) + { + _state_size = new int[] { state_size }; + } + + public StateSizeWrapper(params int[] state_size) + { + _state_size = state_size; + } + public StateSizeWrapper(IEnumerable state_size) + { + _state_size = state_size.ToArray(); + } + + public static implicit operator StateSizeWrapper(int[] state_size) + => new StateSizeWrapper(state_size); + + public static implicit operator StateSizeWrapper(int state_size) + => new StateSizeWrapper(state_size); + + public static implicit operator StateSizeWrapper((int, int) state_size) + => new StateSizeWrapper(state_size.Item1, state_size.Item2); + + public static implicit operator StateSizeWrapper(List v) + => new StateSizeWrapper(v); + public override string ToString() + { + return $"{state_size}"; + } + + public int this[int n] + { + get => n < 0 ? state_size[state_size.Length + n] : state_size[n]; + set => state_size[n] = value; + } + + public IEnumerator GetEnumerator() + { + return state_size.ToList().GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } +} + + diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index ecc9ca11..d49c8218 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -26,6 +26,7 @@ using Tensorflow.Operations; using Tensorflow.Train; using Tensorflow.Util; using static Tensorflow.Binding; +using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs; namespace Tensorflow { @@ -50,7 +51,7 @@ namespace Tensorflow /// matching structure of Tensors having shape `[batch_size].concatenate(s)` /// for each `s` in `self.batch_size`. /// - public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell + public abstract class RnnCell : ILayer { /// /// Attribute that indicates whether the cell is a TF RNN cell, due the slight diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 862b636f..c59e5b16 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -698,6 +698,53 @@ namespace Tensorflow }); } + public static (Tensor, List, Tensors, Tensors) while_loop(Func cond, + Func, Tensors, Tensors, (Tensor, List, Tensors, Tensors)> body, + (Tensor, List, Tensors, Tensors) loop_vars, + int parallel_iterations = 10, + string name = null) + { + var executing_eagerly = tf.Context.executing_eagerly(); + if (!executing_eagerly) + { + throw new NotImplementedException(""); + } + + return tf_with(ops.name_scope("name", "while"), delegate + { + while ((bool)cond(loop_vars.Item1)) + { + loop_vars = body(loop_vars.Item1, loop_vars.Item2, loop_vars.Item3, loop_vars.Item4); + } + + return loop_vars; + }); + } + + public static (Tensor, List, Tensors) while_loop(Func cond, + Func, Tensors, (Tensor, List, Tensors)> body, + (Tensor, List, Tensors) loop_vars, + int parallel_iterations = 10, + string name = null) + { + var executing_eagerly = tf.Context.executing_eagerly(); + if (!executing_eagerly) + { + throw new NotImplementedException(""); + } + + return tf_with(ops.name_scope("name", "while"), delegate + { + while ((bool)cond(loop_vars.Item1)) + { + loop_vars = body(loop_vars.Item1, loop_vars.Item2, loop_vars.Item3); + } + + return loop_vars; + }); + } + + /// /// Repeat `body` while the condition `cond` is true. /// diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index eb94f4d0..2879fa8e 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -211,6 +211,28 @@ namespace Tensorflow.Util => arg is IEnumerable && !(arg is string) && !(arg is NDArray) && !(arg.GetType().IsGenericType && arg.GetType().GetGenericTypeDefinition() == typeof(HashSet<>)); + public static bool is_nested(object obj) + { + // Check if the object is an IEnumerable + if (obj is IEnumerable) + { + // If it is, check if it is a nested structure + foreach (object item in (IEnumerable)obj) + { + if (is_nested(item)) + { + return true; + } + } + return true; + } + else + { + // If it is not, return false + return false; + } + } + public static bool is_mapping(object arg) => arg is IDictionary; //# See the swig file (util.i) for documentation. @@ -263,7 +285,29 @@ namespace Tensorflow.Util } } + public static List FlattenTupple(object tuple) + { + List items = new List(); + var type = tuple.GetType(); + + if (type.GetInterface("ITuple") == null) + throw new ArgumentException("This is not a tuple!"); + foreach (var property in type.GetProperties()) + { + var value = property.GetValue(tuple); + if (property.PropertyType.GetInterface("ITuple") != null) + { + var subItems = FlattenTupple(value); + items.AddRange(subItems); + } + else + { + items.Add((T)value); + } + } + return items; + } //# See the swig file (util.i) for documentation. //_same_namedtuples = _pywrap_tensorflow.SameNamedtuples diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index 80403ad6..da1d25c9 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -22,6 +22,9 @@ using Tensorflow.Functions; using Tensorflow.Graphs; using static Tensorflow.Binding; using static Tensorflow.Graphs.SubGraphUtility; +using Tensorflow.Util; +using Tensorflow.Operations; +using OneOf; namespace Tensorflow.Keras { @@ -65,7 +68,7 @@ namespace Tensorflow.Keras return; } var graph = v.Graph; - if(graph is null) + if (graph is null) { graph = get_graph(); } @@ -95,7 +98,7 @@ namespace Tensorflow.Keras { if (_GRAPH == null) _GRAPH = new FuncGraph("keras_graph"); - + return _GRAPH; } return ops.get_default_graph(); @@ -105,7 +108,7 @@ namespace Tensorflow.Keras { if (_CURRENT_SCRATCH_GRAPH == null) _CURRENT_SCRATCH_GRAPH = new FuncGraph("keras_scratch_graph"); - + return _CURRENT_SCRATCH_GRAPH; } @@ -230,16 +233,16 @@ namespace Tensorflow.Keras { if (outputs[0].op.type == "Const") return tensor_util.constant_value(outputs); - + var source_graph = outputs.graph; var exec_graph = _scratch_graph(); var global_graph = get_graph(); if (source_graph == global_graph && exec_graph != global_graph) { - var lifted_map = lift_to_graph(outputs, exec_graph, - new List(), - add_sources: true, - handle_captures: true, + var lifted_map = lift_to_graph(outputs, exec_graph, + new List(), + add_sources: true, + handle_captures: true, base_graph: source_graph); } if (outputs[0].op.type == "Placeholder" @@ -250,7 +253,7 @@ namespace Tensorflow.Keras exec_graph.as_default(); exec_graph.Inputs = exec_graph.internal_captures; exec_graph.Outputs = outputs; - + var graph_fn = new ConcreteFunction(exec_graph); _CURRENT_SCRATCH_GRAPH = null; @@ -370,7 +373,7 @@ namespace Tensorflow.Keras /// /// /// - public Tensor resize_images(Tensor x, int height_factor, int width_factor, + public Tensor resize_images(Tensor x, int height_factor, int width_factor, string data_format, string interpolation = "nearest") { var (rows, cols) = (0, 0); @@ -412,7 +415,7 @@ namespace Tensorflow.Keras /// public Tensor concatenate(Tensors tensors, int axis = -1) { - if(axis < 0) + if (axis < 0) { var rank = tensors[0].ndim; if (rank > -1) @@ -450,5 +453,520 @@ namespace Tensorflow.Keras return x; } + + public static (Tensors, Tensors) convert_inputs_if_ragged(OneOf inputs) + { + throw new NotImplementedException(); + } + + // + public static (Tensors, Tensors, Tensors) rnn( + Func step_function, // args:inputs, states, return:output, new_states + Tensors inputs, // inputs is a tuple of tensors (one per input sequence) + Tensors initial_states, + bool go_backwards = false, + Tensor? mask = null, + Tensors? constants = null, + bool unroll = false, + Tensors? input_length = null, // An integer or a 1-D Tensor,depending on whether the time dimension is fixed-length or not + bool time_major = false, + bool zero_output_for_mask = false, + bool return_all_outputs = true) + { + + Tensors swap_batch_timestep(Tensors input_t) + { + var axes = Enumerable.Range(0, input_t.rank).ToArray(); + axes[0] = 1; + axes[1] = 0; + return tf.transpose(input_t, axes); + } + + if (!time_major) + { + inputs = nest.map_structure(swap_batch_timestep, inputs); + } + + var flatted_inptus = nest.flatten(inputs); + var time_steps = flatted_inptus[0].shape[0]; + var batch = flatted_inptus[0].shape[1]; + var time_step_t = tf.shape(flatted_inptus[0])[0]; + + foreach (var input_ in flatted_inptus) + { + input_.shape.with_rank_at_least(3); + } + + if (mask != null) + { + if (mask.dtype != TF_DataType.TF_BOOL) + { + mask = tf.cast(mask, TF_DataType.TF_BOOL); + } + + if (mask.rank == 2) + { + mask = tf.expand_dims(mask, -1); + } + + if (!time_major) + { + mask = swap_batch_timestep(mask); + } + + } + + if (constants == null) + { + constants = new List(); + } + + // tf.where needs its condition tensor to be the same shape as its two + // result tensors, but in our case the condition (mask) tensor is + // (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. + // So we need to broadcast the mask to match the shape of inputs. + // That's what the tile call does, it just repeats the mask along its + // second dimension n times. + + Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1) + { + if (nest.is_nested(mask_t)) + { + throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}"); + } + + if (nest.is_nested(input_t)) + { + throw new ValueError($"input_t is expected to be tensor, but got {input_t}"); + } + + var rank_diff = input_t.rank - mask_t.rank; + for (int i = 0; i < rank_diff; i++) + { + mask_t = tf.expand_dims(mask_t, -1); + } + var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank)); + return tf.tile(mask_t, multiples); + } + + Tensors outputs = new Tensors(); + Tensors output_time_zero = new Tensors(); + Tensors last_output = new Tensors(); + Tensors new_states = new Tensors(); + if (unroll) + { + if (time_steps == 0) + { + throw new ValueError("Unrolling requires a fixed number of timesteps."); + } + + // Process the input tensors. The input tensor need to be split on the + // time_step dim, and reverse if go_backwards is True. In the case of + // nested input, the input is flattened and then transformed + // individually. The result of this will be a tuple of lists, each of + // the item in tuple is list of the tensor with shape (batch, feature) + + + // TODO(Wanglongzhi2001),step_func接受的第二个参数为List,但是最后却用的tuple + //var states = Tuple.Create(initial_states); + var states = initial_states; + + var successive_states = new Tensors(); + var successive_outputs = new Tensors(); + + // Process the input tensors. The input tensor need to be split on the + // time_step dim, and reverse if go_backwards is True. In the case of + // nested input, the input is flattened and then transformed + // individually. The result of this will be a tuple of lists, each of + // the item in tuple is list of the tensor with shape (batch, feature) + + + + + Tensors _process_single_input_t(Tensors input_t) + { + input_t = tf.unstack(input_t); // unstack for time_step dim + if (go_backwards) + { + input_t.Reverse(); + } + return input_t; + } + + // TODO(Wanglongzhi2001) + Tensors processed_input; + if (nest.is_nested(inputs)) + { + processed_input = nest.map_structure(_process_single_input_t, inputs); + } + else + { + processed_input = _process_single_input_t(inputs); + } + + object _get_input_tensor(int time) + { + List inp = new List(); + foreach (var t_ in processed_input) + { + inp.Add(t_[time]); + } + return nest.pack_sequence_as(inputs, inp); + } + + if (mask != null) + { + var mask_list = tf.unstack(mask); + if (go_backwards) + { + mask_list.Reverse(); + } + + for (int i = 0; i < time_steps; i++) + { + // TODO(Wanglongzhi2001),deal with _get_input_tensor + var inp = _get_input_tensor(i); + var mask_t = mask_list[i]; + // TODO + var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); + + var tiled_mask_t = _expand_mask(mask_t, output); + + Tensors prev_output; + if (successive_outputs == null) + { + prev_output = tf.zeros_like(output); + } + else + { + prev_output = successive_outputs[successive_outputs.Length - 1]; + } + + output = tf.where(tiled_mask_t, output, prev_output); + + //var flat_states = nest.flatten(states); + //var flat_new_states = nest.flatten(newStates); + var flat_states = states.ToList(); + var flat_new_states = newStates.ToList(); + + var tiledMaskT = flat_states + .Select(s => _expand_mask(mask_t, s)) + .ToArray(); + var tuple = Tuple.Create(tiledMaskT); + + List flat_final_states = new List(); + foreach (var (m, s, ps) in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states)) + { + flat_final_states.Add(tf.where(m, s, ps)); + } + + states = (Tensors)nest.pack_sequence_as(states, flat_final_states); + if (return_all_outputs) + { + successive_outputs.Add(output); + successive_states.Add(states); + } + else + { + successive_outputs = new Tensors { output }; + successive_states = new Tensors { states }; + } + + } + last_output = successive_outputs[successive_outputs.Length - 1]; + new_states = successive_states[successive_states.Length - 1]; + outputs = tf.stack(successive_outputs); + + if (zero_output_for_mask) + { + last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); + outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); + } + else // mask is null + { + for (int i = 0; i < time_steps; i++) + { + var inp = _get_input_tensor(i); + var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); + states = newStates; + + if (return_all_outputs) + { + successive_outputs.Add(output); + successive_states.Add(newStates); + } + else + { + successive_outputs = new Tensors { output }; + successive_states = new Tensors { newStates }; + } + } + last_output = successive_outputs[successive_outputs.Length - 1]; + new_states = successive_states[successive_states.Length - 1]; + outputs = tf.stack(successive_outputs); + } + } + } + else // unroll == false + { + var states = initial_states; + // Create input tensor array, if the inputs is nested tensors, then it + // will be flattened first, and tensor array will be created one per + // flattened tensor. + var input_ta = new List(); + for (int i = 0; i < flatted_inptus.Count; i++) + { + input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t)); + } + + // Get the time(0) input and compute the output for that, the output will + // be used to determine the dtype of output tensor array. Don't read from + // input_ta due to TensorArray clear_after_read default to True. + var inps = new Tensors(); + foreach (var inp in flatted_inptus) + { + inps.Add(inp[0]); + } + var input_time_zero = nest.pack_sequence_as(inputs, inps); + + // output_time_zero is used to determine the cell output shape and its + // dtype. the value is discarded. + (output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants }); + + var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1); + var output_ta = new List(); + for (int i = 0; i < output_time_zero.ToList().Count; i++) + { + var Out = output_time_zero.ToList()[i]; + output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); + } + + var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); + + + + Func? masking_fn; + Func? compute_masked_output = null; + if (mask != null) + { + if (go_backwards) + { + mask = tf.reverse(mask, axis: new[] { 0 }); + } + var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t); + mask_ta = mask_ta.unstack(mask); + + masking_fn = (time) => + { + return mask_ta.read(time); + }; + + compute_masked_output = (mask_t, flat_out, flat_mask) => + { + var tiled_mask_t = new Tensors(); + foreach (var o in flat_out) + { + tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); + } + + Tensors res = new Tensors(); + foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask)) + { + res.Add(tf.where(m, o, fm)); + } + return res; + }; + } + // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)? + else if (input_length is Tensor) + { + if (go_backwards) + { + var max_len = tf.reduce_max(input_length, axis: 0); + var rev_input_length = tf.subtract(max_len - 1, input_length); + + masking_fn = (time) => + { + return tf.less(rev_input_length, time); + }; + } + else + { + masking_fn = (time) => + { + return tf.greater(input_length, time); + }; + } + + compute_masked_output = (mask_t, flat_out, flat_mask) => + { + var res = new List(); + foreach (var (o, zo) in zip(flat_out, flat_mask)) + { + res.Add(tf.where(mask_t, o, zo)); + } + return res; + }; + } + else + { + masking_fn = null; + } + + + if (masking_fn != null) + { + // Mask for the T output will be base on the output of T - 1. In the + // case T = 0, a zero filled tensor will be used. + var flat_zero_output = new Tensors(); + foreach (var o in nest.flatten(output_time_zero)) + { + flat_zero_output.Add(tf.zeros_like(o)); + } + + + (Tensor, List, Tensors, Tensors) _step(Tensor time, List output_ta_t, Tensors prev_output, Tensors states) + { + /* + RNN step function. + Args: + time: Current timestep value. + output_ta_t: TensorArray. + prev_output: tuple of outputs from time - 1. + *states: List of states. + Returns: + Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` + */ + + var current_input = input_ta.Select(x => x.read(time)).ToList(); + // maybe set shape + // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + current_input = (List)nest.pack_sequence_as(inputs, current_input); + var mask_t = masking_fn(time); + var (output, new_states) = step_function(current_input, new Tensors { states, constants }); + // mask output + //var flat_output = nest.flatten(output); + var flat_output = output.ToList(); + + var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); + + // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type + var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); + + // mask states + var flat_state = states.ToList(); + var flat_new_state = new_states.ToList(); + + foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + { + if (new_state is Tensor) + { + new_state.set_shape(state.shape); + } + } + + var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); + new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state); + + var ta_index_to_write = return_all_outputs ? time : tf.constant(0); + var Output_ta_t = new List(); + // TODO(Wanglongzhi2001),deal with zip output_ta_t + foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) + { + Output_ta_t.Add(ta.write(ta_index_to_write, Out)); + } + + + + //new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); + + + return (time + 1, Output_ta_t, flat_new_output, new_states); + + } + Func cond = (time) => (time < time_step_t); + + var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states)); + new_states = final_outputs.Item4; + output_ta = final_outputs.Item2; + + } + else + { + (Tensor, List, Tensors) _step(Tensor time, List output_ta_t, Tensors states) + { + var current_input = input_ta.Select(x => x.read(time)).ToList(); + // maybe set shape + // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + current_input = (List)nest.pack_sequence_as(inputs, current_input); + var (output, new_states) = step_function(current_input, new Tensors { states, constants }); + var flat_state = states.ToList(); + var flat_new_state = new_states.ToList(); + foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + { + if (new_state is Tensor) + { + new_state.set_shape(state.shape); + } + } + var flat_output = output.ToList(); + var ta_index_to_write = return_all_outputs ? time : tf.constant(0); + var Output_ta_t = new List(); + foreach (var (ta, out_) in zip(output_ta_t, flat_output)) + { + Output_ta_t.Add(ta.write(ta_index_to_write, out_)); + } + + new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); + return (time + 1, Output_ta_t, new_states); + } + Func cond = (time) => (time < time_step_t); + var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states)); + new_states = final_outputs.Item3; + output_ta = final_outputs.Item2; + + } + //Tensors outputs = new Tensors(); + foreach (var o in output_ta) + { + outputs.Add(o.stack()); + } + foreach (var o in outputs) + { + last_output.Add(o[-1]); + } + outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs); + last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output); + + } + + Func set_shape; + set_shape = (output_) => + { + if (output_ is Tensor) + { + var shape = output_.shape.as_int_list(); + if (return_all_outputs) + { + shape[0] = (int)time_steps; + } + else + { + shape[0] = 1; + } + shape[1] = (int)batch; + output_.set_shape(new Tensor(shape)); + } + return output_; + }; + + var Outputs = (Tensors)nest.map_structure(set_shape, outputs); + if (!time_major) + { + Outputs = nest.map_structure(swap_batch_timestep, outputs); + } + return (last_output, Outputs, new_states); + + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 5942efd9..4216c725 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -332,9 +332,9 @@ namespace Tensorflow.Keras.Engine /// /// /// - protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected virtual Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { - if(ReplacedCall is not null) + if (ReplacedCall is not null) { return ReplacedCall(inputs); } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs index 739c0d56..9fb8781e 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs @@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Layers { base.build(input_shape); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor output = inputs; output = tf.where(output > 0f, output, diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs index 17636302..2f618f63 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers { { base.build(input_shape); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor output = inputs; return tf.exp(output); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs b/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs index b498d1b9..efea135b 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs @@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers { public HardSigmoid ( LayerArgs args ) : base(args) { // hard sigmoid has no arguments } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) + { Tensor x = inputs; return tf.clip_by_value( tf.add(tf.multiply(x, 0.2f), 0.5f), 0f, 1f); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs b/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs index 1fbbf4ea..feb98a0b 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { return tf.nn.leaky_relu(inputs, alpha: alpha); } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs index 53101fbb..b444e338 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs @@ -22,7 +22,8 @@ namespace Tensorflow.Keras.Layers { } base.build(input_shape); } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) + { Tensor output = inputs; return tf.where(output > 0f, tf.multiply(scale, output), diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs b/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs index 3ffae27f..62d2461e 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs @@ -11,7 +11,8 @@ namespace Tensorflow.Keras.Layers { public Softmax ( SoftmaxArgs args ) : base(args) { axis = args.axis; } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) + { Tensor x = inputs.Length == 2 ? inputs + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9) : inputs; Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true))); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs b/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs index e82b0198..13dfad4e 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs @@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers { public Softplus ( LayerArgs args ) : base(args) { // Softplus has no arguments } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) + { Tensor x = inputs; return tf.log( tf.add(tf.exp(x), 1f)); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs b/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs index 59329fd4..9933db5f 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs @@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers { public Softsign ( LayerArgs args ) : base(args) { // Softsign has no arguments } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) + { Tensor x = inputs; // x / (abs(x) + 1) return tf.div(x, tf.add(1f, tf.abs(x))); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs b/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs index 1dcb92b3..727d385d 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs @@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers { public Swish ( LayerArgs args ) : base(args) { // Swish has no arguments } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) + { Tensor x = inputs; // x / (1 + exp(-x)) diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs b/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs index 99b80394..802b894e 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs @@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Layers { // Tanh has no arguments } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor x = inputs; diff --git a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs index 1348e19c..fe37d860 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs @@ -114,7 +114,7 @@ namespace Tensorflow.Keras.Layers return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensors _inp; Tensors _mask = null; diff --git a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs index 701724d5..f3fee090 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs @@ -252,7 +252,7 @@ namespace Tensorflow.Keras.Layers return (attention_output, attention_scores); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensors _inp; Tensor _mask = null; diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs index c575362c..cf0c6d2b 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs @@ -103,7 +103,7 @@ namespace Tensorflow.Keras.Layers _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { var outputs = _convolution_op.Apply(inputs, kernel.AsTensor()); if (use_bias) diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs index aa6617dd..f574fd53 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -69,7 +69,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor outputs = null; var rank = inputs.rank; diff --git a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs index fb604f77..9aacf8f1 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs @@ -189,7 +189,7 @@ namespace Tensorflow.Keras.Layers // return new dict(base_config.items().ToList() + config.items().ToList()); //} - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { var ret = tf.linalg.einsum(this.equation, (inputs, this.kernel.AsTensor())); if (this.bias != null) diff --git a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs index 9487a7d0..6e074978 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs @@ -66,7 +66,7 @@ namespace Tensorflow.Keras.Layers _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs index 7df654ee..2d7a1e7d 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { return _merge_function(inputs); } diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs index d02d2509..2af14cc7 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -146,7 +146,7 @@ namespace Tensorflow.Keras.Layers return false; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor outputs = null; var training_tensor = training == null diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs index e90c0402..e708d6a8 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs @@ -101,7 +101,7 @@ namespace Tensorflow.Keras.Layers return input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor outputs = null; var inputs_dtype = inputs.dtype.as_base_dtype(); diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs index a65154bf..978d1029 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs @@ -157,7 +157,7 @@ namespace Tensorflow.Keras.Layers base.adapt(data, batch_size: batch_size, steps: steps); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (_args.Invert) { diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs index d62fb63a..21a21406 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs @@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers { } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (data_format == "channels_last") return math_ops.reduce_mean(inputs, 1, false); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs index 000e4b8b..e03050a9 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs @@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers { } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (data_format == "channels_last") return math_ops.reduce_mean(inputs, (1, 2), false); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs index 2de4671c..1a8f06dd 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs @@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers { } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (data_format == "channels_last") return math_ops.reduce_max(inputs, 1, false); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs index b7e2c945..9ce002f0 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs @@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers { } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (data_format == "channels_last") return math_ops.reduce_max(inputs, (1, 2), false); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs index a2f4c51b..65c6130d 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs @@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers input_spec = new InputSpec(ndim: 3); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { int pad_axis = args.DataFormat == "channels_first" ? 2 : 3; inputs = tf.expand_dims(inputs, pad_axis); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs index 27032255..4804d0ab 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs @@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers input_spec = new InputSpec(ndim: 4); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { int[] pool_shape; int[] strides; diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs index 5620a916..a7e1fd19 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs @@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { var depth = args.NumTokens; var max_value = tf.reduce_max(inputs); diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs index 5fc581af..99194ca6 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { scale = constant_op.constant(args.Scale, args.DType); offset = constant_op.constant(args.Offset, args.DType); diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs index 603e2b07..67e4b464 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { return image_ops_impl.resize_images_v2(inputs, new[] { args.Height, args.Width }, method: args.Interpolation); } diff --git a/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs b/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs index aa3a92a4..696ab5b9 100644 --- a/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs +++ b/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs @@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (training == null) training = false; diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs index 9ead15cb..cf93c169 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs @@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Layers.Reshaping _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor output = inputs; if (output.rank != 3) diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs index 087d59a1..7872b0b0 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers.Reshaping built = true; _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor output = inputs; if (output.rank != 4) diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs index 04a1af60..5bc2433b 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers.Reshaping _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor output = inputs; if (output.rank != 5) diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs index 539b5f62..8ff34134 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs @@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Layers _channels_first = args.DataFormat == "channels_first"; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (_channels_first) { diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs index e391775c..79f0b569 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs @@ -28,7 +28,7 @@ namespace Tensorflow.Keras.Layers { built = true; _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor outputs = inputs; return tf.transpose(outputs, new Axis(permute)); diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs index 92a772f3..8a4e4d5f 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { var shapes = new List(); shapes.Add(array_ops.shape(inputs)[0]); diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs index 8314151f..7e926dee 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs @@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers inputSpec = new InputSpec(ndim: 4); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { return keras.backend.resize_images(inputs, size[0], size[1], diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs index 7c87100a..c68def38 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs @@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Layers this.input_spec = new InputSpec(ndim: 4); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { return keras.backend.spatial_2d_padding(inputs, padding: padding, diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs index 59555e62..530c409e 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs @@ -26,9 +26,9 @@ namespace Tensorflow.Keras.Layers.Rnn .ToArray(); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { - return base.Call(inputs, state: state, training: training); + return base.Call(inputs, initial_state: initial_state, training: training); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index 310e8057..7bd4047a 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -1,9 +1,15 @@ using System; +using System.Collections; using System.Collections.Generic; -using Tensorflow.Keras.ArgsDefinition; +using System.Reflection; +using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Util; +using OneOf; +using OneOf.Types; +using Tensorflow.Common.Extensions; // from tensorflow.python.distribute import distribution_strategy_context as ds_context; namespace Tensorflow.Keras.Layers.Rnn @@ -19,11 +25,46 @@ namespace Tensorflow.Keras.Layers.Rnn protected IVariableV1 kernel; protected IVariableV1 bias; protected ILayer cell; + public RNN(RNNArgs args) : base(PreConstruct(args)) { this.args = args; SupportsMasking = true; + // if is StackedRnncell + if (args.Cell.IsT0) + { + cell = new StackedRNNCells(new StackedRNNCellsArgs + { + Cells = args.Cell.AsT0, + }); + } + else + { + cell = args.Cell.AsT1; + } + + + + + + Type type = cell.GetType(); + MethodInfo methodInfo = type.GetMethod("Call"); + if (methodInfo == null) + { + throw new ValueError(@"Argument `cell` or `cells`should have a `call` method. "); + } + + PropertyInfo propertyInfo = type.GetProperty("state_size"); + if (propertyInfo == null) + { + throw new ValueError(@"The RNN cell should have a `state_size` attribute"); + } + + + + // get input_shape + this.args = PreConstruct(args); // The input shape is unknown yet, it could have nested tensor inputs, and // the input spec will be the list of specs for nested inputs, the structure // of the input_spec will be the same as the input. @@ -37,17 +78,384 @@ namespace Tensorflow.Keras.Layers.Rnn //} } + // States is a tuple consist of cell states_size, like (cell1.state_size, cell2.state_size,...) + // state_size can be a single integer, can also be a list/tuple of integers, can also be TensorShape or a list/tuple of TensorShape + public object States + { + get + { + if (_states == null) + { + var state = nest.map_structure(x => null, cell.state_size); + return nest.is_nested(state) ? state : new Tensors { state }; + } + return _states; + } + set { _states = value; } + } + + private OneOf> compute_output_shape(Shape input_shape) + { + var batch = input_shape[0]; + var time_step = input_shape[1]; + if (args.TimeMajor) + { + (batch, time_step) = (time_step, batch); + } + + // state_size is a array of ints or a positive integer + var state_size = cell.state_size; + + + // TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor + Func _get_output_shape; + _get_output_shape = (flat_output_size) => + { + var output_dim = flat_output_size.as_int_list(); + Shape output_shape; + if (args.ReturnSequences) + { + if (args.TimeMajor) + { + output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim)); + } + else + { + output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim)); + + } + } + else + { + output_shape = new Shape(new int[] { (int)batch }.concat(output_dim)); + } + return output_shape; + }; + + Shape output_shape; + if (cell.output_size != 0) + { + output_shape = nest.map_structure(_get_output_shape, cell.output_size); + // TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型 + output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape); + } + else + { + output_shape = _get_output_shape(state_size[0]); + } + + if (args.ReturnState) + { + Func _get_state_shape; + _get_state_shape = (flat_state) => + { + var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list()); + return new Shape(state_shape); + }; + var state_shape = _get_state_shape(new Shape(state_size.ToArray())); + + return new List { output_shape, state_shape }; + } + else + { + return output_shape; + } + } + + private Tensors compute_mask(Tensors inputs, Tensors mask) + { + // Time step masks must be the same for each input. + // This is because the mask for an RNN is of size [batch, time_steps, 1], + // and specifies which time steps should be skipped, and a time step + // must be skipped for all inputs. + + mask = nest.flatten(mask)[0]; + var output_mask = args.ReturnSequences ? mask : null; + if (args.ReturnState) + { + var state_mask = new List(); + for (int i = 0; i < len(States); i++) + { + state_mask.Add(null); + } + return new List { output_mask }.concat(state_mask); + } + else + { + return output_mask; + } + + + } + public override void build(KerasShapesWrapper input_shape) { + object get_input_spec(Shape shape) + { + var input_spec_shape = shape.as_int_list(); + + var (batch_index, time_step_index) = args.TimeMajor ? (1, 0) : (0, 1); + if (!args.Stateful) + { + input_spec_shape[batch_index] = -1; + } + input_spec_shape[time_step_index] = -1; + return new InputSpec(shape: input_spec_shape); + } + + Shape get_step_input_shape(Shape shape) + { + + // return shape[1:] if self.time_major else (shape[0],) + shape[2:] + if (args.TimeMajor) + { + return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray(); + } + else + { + return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray()); + } + + + } + + object get_state_spec(Shape shape) + { + var state_spec_shape = shape.as_int_list(); + // append bacth dim + state_spec_shape = new int[] { -1 }.concat(state_spec_shape); + return new InputSpec(shape: state_spec_shape); + + } + + // Check whether the input shape contains any nested shapes. It could be + // (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from + // numpy inputs. + + if (!cell.Built) { cell.build(input_shape); } } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + // inputs: Tensors + // mask: Binary tensor of shape [batch_size, timesteps] indicating whether a given timestep should be masked + // training: bool + // initial_state: List of initial state tensors to be passed to the first call of the cell + // constants: List of constant tensors to be passed to the cell at each timestep + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { - return base.Call(inputs, state, training); + //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs); + //bool is_ragged_input = row_length != null; + //_validate_args_if_ragged(is_ragged_input, mask); + var (inputs_processed, initial_state_processed, constants_processed) = _process_inputs(inputs, initial_state, constants); + + _maybe_reset_cell_dropout_mask(cell); + if (cell is StackedRNNCells) + { + foreach (var cell in ((StackedRNNCells)cell).Cells) + { + _maybe_reset_cell_dropout_mask(cell); + } + } + + if (mask != null) + { + // Time step masks must be the same for each input. + //mask = nest.flatten(mask)[0]; + mask = mask[0]; + } + + + Shape input_shape; + if (nest.is_nested(initial_state_processed)) + { + // In the case of nested input, use the first element for shape check + // input_shape = nest.flatten(inputs)[0].shape; + input_shape = inputs[0].shape; + } + else + { + input_shape = inputs.shape; + } + + var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1]; + + if (args.Unroll && timesteps != null) + { + throw new ValueError( + "Cannot unroll a RNN if the " + + "time dimension is undefined. \n" + + "- If using a Sequential model, " + + "specify the time dimension by passing " + + "an `input_shape` or `batch_input_shape` " + + "argument to your first layer. If your " + + "first layer is an Embedding, you can " + + "also use the `input_length` argument.\n" + + "- If using the functional API, specify " + + "the time dimension by passing a `shape` " + + "or `batch_shape` argument to your Input layer." + ); + } + + // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call) + var cell_call_fn = cell.Call; + Func step; + if (constants != null) + { + ParameterInfo[] parameters = cell_call_fn.GetMethodInfo().GetParameters(); + bool hasParam = parameters.Any(p => p.Name == "constants"); + if (!hasParam) + { + throw new ValueError( + $"RNN cell {cell} does not support constants." + + $"Received: constants={constants}"); + } + + step = (inputs, states) => + { + // constants = states[-self._num_constants :] + constants = states.numpy()[new Slice(states.Length - _num_constants, states.Length)]; + // states = states[: -self._num_constants] + states = states.numpy()[new Slice(0, states.Length - _num_constants)]; + // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states) + states = states.Length == 1 ? states[0] : states; + var (output, new_states) = cell_call_fn(inputs, null, null, states, constants); + if (!nest.is_nested(new_states)) + { + return (output, new Tensors { new_states }); + } + return (output, new_states); + }; + } + else + { + step = (inputs, states) => + { + // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states) + states = states.Length == 1 ? states[0] : states; + var (output, new_states) = cell_call_fn(inputs, null, null, states, constants); + if (!nest.is_nested(new_states)) + { + return (output, new Tensors { new_states }); + } + return (output, new_states); + }; + } + + var (last_output, outputs, states) = BackendImpl.rnn(step, + inputs, + initial_state, + constants: constants, + go_backwards: args.GoBackwards, + mask: mask, + unroll: args.Unroll, + input_length: row_length != null ? row_length : new Tensor(timesteps), + time_major: args.TimeMajor, + zero_output_for_mask: args.ZeroOutputForMask, + return_all_outputs: args.ReturnSequences); + + if (args.Stateful) + { + throw new NotImplementedException("this argument havn't been developed!"); + } + + Tensors output = new Tensors(); + if (args.ReturnSequences) + { + throw new NotImplementedException("this argument havn't been developed!"); + + } + else + { + output = last_output; + } + + if (args.ReturnState) + { + + foreach (var state in states) + { + output.Add(state); + } + return output; + } + else + { + return output; + } + } + + private (Tensors, Tensors, Tensors) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants) + { + bool IsSequence(object obj) + { + // Check if the object is an IEnumerable + if (obj is IEnumerable) + { + // If it is, check if it is a tuple + if (!(obj is Tuple)) + { + return true; + } + } + // If it is not, return false + return false; + } + + if (IsSequence(input)) + { + if (_num_constants != 0) + { + initial_state = inputs[new Slice(1, len(inputs))]; + } + else + { + initial_state = inputs[new Slice(1, len(inputs) - _num_constants)]; + } + if (len(initial_state) == 0) + initial_state = null; + inputs = inputs[0]; + } + + if (args.Stateful) + { + throw new NotImplementedException("argument stateful has not been implemented!"); + + } + + return (inputs, initial_state, constants); + + } + + private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask) + { + if (is_ragged_input) + { + if (args.Unroll) + { + throw new ValueError("The input received contains RaggedTensors and does " + + "not support unrolling. Disable unrolling by passing " + + "`unroll=False` in the RNN Layer constructor."); + } + if (mask != null) + { + throw new ValueError($"The mask that was passed in was {mask}, which " + + "cannot be applied to RaggedTensor inputs. Please " + + "make sure that there is no mask injected by upstream " + + "layers."); + } + } + } + + void _maybe_reset_cell_dropout_mask(ILayer cell) + { + //if (cell is DropoutRNNCellMixin) + //{ + // cell.reset_dropout_mask(); + // cell.reset_recurrent_dropout_mask(); + //} } private static RNNArgs PreConstruct(RNNArgs args) @@ -77,6 +485,10 @@ namespace Tensorflow.Keras.Layers.Rnn return args; } + public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null) + { + throw new NotImplementedException(); + } public RNN New(LayerRnnCell cell, bool return_sequences = false, bool return_state = false, @@ -95,7 +507,7 @@ namespace Tensorflow.Keras.Layers.Rnn TimeMajor = time_major }); - public RNN New(IList cell, + public RNN New(IList cell, bool return_sequences = false, bool return_state = false, bool go_backwards = false, @@ -125,7 +537,7 @@ namespace Tensorflow.Keras.Layers.Rnn } // Check whether the state_size contains multiple states. - public static bool _is_multiple_state(object state_size) + public static bool is_multiple_state(object state_size) { var myIndexerProperty = state_size.GetType().GetProperty("Item"); return myIndexerProperty != null diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index 46061b21..86985d7e 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -42,9 +42,9 @@ namespace Tensorflow.Keras.Layers.Rnn built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { - return base.Call(inputs, state, training); + return base.Call(inputs, initial_state, training); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs index 20962df1..8e67f8e8 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs @@ -2,15 +2,16 @@ using System.Collections.Generic; using System.ComponentModel; using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.ArgsDefinition.Rnn; +using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Keras.ArgsDefinition.Rnn; namespace Tensorflow.Keras.Layers.Rnn { - public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell + public class StackedRNNCells : Layer { - public IList Cells { get; set; } + public IList Cells { get; set; } public bool reverse_state_order; public StackedRNNCells(StackedRNNCellsArgs args) : base(args) @@ -51,7 +52,7 @@ namespace Tensorflow.Keras.Layers.Rnn { return lastCell.output_size; } - else if (RNN._is_multiple_state(lastCell.state_size)) + else if (RNN.is_multiple_state(lastCell.state_size)) { // return ((dynamic)Cells[-1].state_size)[0]; throw new NotImplementedException(""); @@ -63,6 +64,7 @@ namespace Tensorflow.Keras.Layers.Rnn } } + public object get_initial_state() { throw new NotImplementedException(); @@ -80,7 +82,7 @@ namespace Tensorflow.Keras.Layers.Rnn // return tuple(initial_states) } - public object call() + public Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { throw new NotImplementedException(); // def call(self, inputs, states, constants= None, training= None, ** kwargs): diff --git a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs index 1ac4a277..6dfc089b 100644 --- a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs @@ -34,7 +34,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (tf.Context.executing_eagerly()) return DeFunCall(inputs);