From bad4e8160570fea5c786bd17b93c2f246a92aa36 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Fri, 2 Jun 2023 20:55:22 +0800 Subject: [PATCH] update draft pr for RNN --- .../Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs | 5 +- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 6 + .../Keras/Layers/ILayersApi.cs | 8 + .../Operations/NnOps/RNNCell.cs | 7 + src/TensorFlowNET.Core/Util/nest.py.cs | 68 +- src/TensorFlowNET.Keras/BackendImpl.cs | 670 +++++++++--------- src/TensorFlowNET.Keras/Engine/Functional.cs | 2 +- src/TensorFlowNET.Keras/Engine/Layer.Apply.cs | 2 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 103 +-- src/TensorFlowNET.Keras/Engine/Sequential.cs | 6 +- .../Layers/Convolution/Conv2DTranspose.cs | 2 +- src/TensorFlowNET.Keras/Layers/LayersApi.cs | 17 + .../Layers/Rnn/DropOutRNNCellMixin.cs | 80 +++ src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 236 +++--- .../Layers/Rnn/RNNUtils.cs | 59 ++ .../Layers/Rnn/SimpleRNNCell.cs | 92 ++- src/TensorflowNET.Hub/KerasLayer.cs | 2 +- .../Callbacks/EarlystoppingTest.cs | 60 -- .../Layers/LayersTest.cs | 11 + .../Tensorflow.Keras.UnitTest.csproj | 4 + 20 files changed, 883 insertions(+), 557 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs create mode 100644 src/TensorFlowNET.Keras/Layers/Rnn/RNNUtils.cs delete mode 100644 test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs index fcfd694d..d8fdfae5 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs @@ -2,6 +2,9 @@ { public class SimpleRNNArgs : RNNArgs { - + public float Dropout = 0f; + public float RecurrentDropout = 0f; + public int state_size; + public int output_size; } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index f7669394..8bcefc1d 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -27,5 +27,11 @@ namespace Tensorflow.Keras TF_DataType DType { get; } int count_params(); void adapt(Tensor data, int? batch_size = null, int? steps = null); + + Tensors Call(Tensors inputs, Tensor? mask = null, bool? training = null, Tensors? initial_state = null, Tensors? constants = null); + + StateSizeWrapper state_size { get; } + + int output_size { get; } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs index 6a29f9e5..e60ba6fc 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs @@ -200,6 +200,14 @@ namespace Tensorflow.Keras.Layers bool return_sequences = false, bool return_state = false); + public ILayer SimpleRNNCell( + int units, + string activation = "tanh", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros"); + public ILayer Subtract(); } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index d49c8218..2dc70177 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -89,6 +89,8 @@ namespace Tensorflow protected bool built = false; public bool Built => built; + StateSizeWrapper ILayer.state_size => throw new NotImplementedException(); + public RnnCell(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, @@ -174,5 +176,10 @@ namespace Tensorflow { throw new NotImplementedException(); } + + public Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index 2879fa8e..8fa9dcac 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -19,6 +19,7 @@ using System; using System.Collections; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; namespace Tensorflow.Util { @@ -213,6 +214,17 @@ namespace Tensorflow.Util public static bool is_nested(object obj) { + // Refer to https://www.tensorflow.org/api_docs/python/tf/nest + //if (obj is IList || obj is IDictionary || obj is ITuple) + // return true; + if (obj is IList || obj is IDictionary) + return true; + + if (obj is NDArray || obj is Tensor || obj is string || obj.GetType().IsGenericType + || obj is ISet || obj is ISet || obj is ISet) + return false; + + if (obj.GetType().IsNested) return true; // Check if the object is an IEnumerable if (obj is IEnumerable) { @@ -244,7 +256,13 @@ namespace Tensorflow.Util _flatten_recursive(structure, list); return list; } - + // TODO(Wanglongzhi2001), ITuple must used in .NET standard 2.1, but now is 2.0 + // If you want to flatten a nested tuple, please specify the type of the tuple + //public static List flatten(ITuple structure) + //{ + // var list = FlattenTuple(structure).ToList(); + // return list; + //} public static List flatten(IEnumerable structure) { var list = new List(); @@ -272,9 +290,13 @@ namespace Tensorflow.Util case String str: list.Add(obj); break; - case NDArray nd: + // This case can hold both Tensor and NDArray + case Tensor tensor: list.Add(obj); break; + //case NDArray nd: + // list.Add(obj); + // break; case IEnumerable structure: foreach (var child in structure) _flatten_recursive((T)child, list); @@ -285,28 +307,26 @@ namespace Tensorflow.Util } } - public static List FlattenTupple(object tuple) + private static IEnumerable FlattenTuple(object tuple) { - List items = new List(); - var type = tuple.GetType(); - - if (type.GetInterface("ITuple") == null) - throw new ArgumentException("This is not a tuple!"); + //if (tuple is ITuple t) + //{ + // for (int i = 0; i < t.Length; i++) + // { + // foreach (var item in FlattenTuple(t[i])) + // { + // yield return item; + // } + // } + //} + if(false) + { - foreach (var property in type.GetProperties()) + } + else { - var value = property.GetValue(tuple); - if (property.PropertyType.GetInterface("ITuple") != null) - { - var subItems = FlattenTupple(value); - items.AddRange(subItems); - } - else - { - items.Add((T)value); - } + yield return (T)tuple; } - return items; } //# See the swig file (util.i) for documentation. //_same_namedtuples = _pywrap_tensorflow.SameNamedtuples @@ -494,8 +514,12 @@ namespace Tensorflow.Util throw new ArgumentException("flat_sequence must not be null"); // if not is_sequence(flat_sequence): // raise TypeError("flat_sequence must be a sequence") - - if (!is_sequence(structure)) + if (!is_nested(flat_sequence)) + { + throw new ArrayTypeMismatchException($"Attempted to pack value:\\n {flat_sequence}\\ninto a structure, " + + $"but found incompatible type `{flat_sequence.GetType()}` instead."); + } + if (!is_nested(structure)) { if (len(flat) != 1) throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat)} > 1"); diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index da1d25c9..94aeb0dd 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -614,331 +614,331 @@ namespace Tensorflow.Keras return nest.pack_sequence_as(inputs, inp); } - if (mask != null) - { - var mask_list = tf.unstack(mask); - if (go_backwards) - { - mask_list.Reverse(); - } - - for (int i = 0; i < time_steps; i++) - { - // TODO(Wanglongzhi2001),deal with _get_input_tensor - var inp = _get_input_tensor(i); - var mask_t = mask_list[i]; - // TODO - var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); - - var tiled_mask_t = _expand_mask(mask_t, output); - - Tensors prev_output; - if (successive_outputs == null) - { - prev_output = tf.zeros_like(output); - } - else - { - prev_output = successive_outputs[successive_outputs.Length - 1]; - } - - output = tf.where(tiled_mask_t, output, prev_output); - - //var flat_states = nest.flatten(states); - //var flat_new_states = nest.flatten(newStates); - var flat_states = states.ToList(); - var flat_new_states = newStates.ToList(); - - var tiledMaskT = flat_states - .Select(s => _expand_mask(mask_t, s)) - .ToArray(); - var tuple = Tuple.Create(tiledMaskT); - - List flat_final_states = new List(); - foreach (var (m, s, ps) in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states)) - { - flat_final_states.Add(tf.where(m, s, ps)); - } - - states = (Tensors)nest.pack_sequence_as(states, flat_final_states); - if (return_all_outputs) - { - successive_outputs.Add(output); - successive_states.Add(states); - } - else - { - successive_outputs = new Tensors { output }; - successive_states = new Tensors { states }; - } - - } - last_output = successive_outputs[successive_outputs.Length - 1]; - new_states = successive_states[successive_states.Length - 1]; - outputs = tf.stack(successive_outputs); - - if (zero_output_for_mask) - { - last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); - outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); - } - else // mask is null - { - for (int i = 0; i < time_steps; i++) - { - var inp = _get_input_tensor(i); - var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); - states = newStates; - - if (return_all_outputs) - { - successive_outputs.Add(output); - successive_states.Add(newStates); - } - else - { - successive_outputs = new Tensors { output }; - successive_states = new Tensors { newStates }; - } - } - last_output = successive_outputs[successive_outputs.Length - 1]; - new_states = successive_states[successive_states.Length - 1]; - outputs = tf.stack(successive_outputs); - } - } - } - else // unroll == false - { - var states = initial_states; - // Create input tensor array, if the inputs is nested tensors, then it - // will be flattened first, and tensor array will be created one per - // flattened tensor. - var input_ta = new List(); - for (int i = 0; i < flatted_inptus.Count; i++) - { - input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t)); - } - - // Get the time(0) input and compute the output for that, the output will - // be used to determine the dtype of output tensor array. Don't read from - // input_ta due to TensorArray clear_after_read default to True. - var inps = new Tensors(); - foreach (var inp in flatted_inptus) - { - inps.Add(inp[0]); - } - var input_time_zero = nest.pack_sequence_as(inputs, inps); - - // output_time_zero is used to determine the cell output shape and its - // dtype. the value is discarded. - (output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants }); - - var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1); - var output_ta = new List(); - for (int i = 0; i < output_time_zero.ToList().Count; i++) - { - var Out = output_time_zero.ToList()[i]; - output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); - } - - var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); - - - - Func? masking_fn; - Func? compute_masked_output = null; - if (mask != null) - { - if (go_backwards) - { - mask = tf.reverse(mask, axis: new[] { 0 }); - } - var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t); - mask_ta = mask_ta.unstack(mask); - - masking_fn = (time) => - { - return mask_ta.read(time); - }; - - compute_masked_output = (mask_t, flat_out, flat_mask) => - { - var tiled_mask_t = new Tensors(); - foreach (var o in flat_out) - { - tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); - } - - Tensors res = new Tensors(); - foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask)) - { - res.Add(tf.where(m, o, fm)); - } - return res; - }; - } - // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)? - else if (input_length is Tensor) - { - if (go_backwards) - { - var max_len = tf.reduce_max(input_length, axis: 0); - var rev_input_length = tf.subtract(max_len - 1, input_length); - - masking_fn = (time) => - { - return tf.less(rev_input_length, time); - }; - } - else - { - masking_fn = (time) => - { - return tf.greater(input_length, time); - }; - } - - compute_masked_output = (mask_t, flat_out, flat_mask) => - { - var res = new List(); - foreach (var (o, zo) in zip(flat_out, flat_mask)) - { - res.Add(tf.where(mask_t, o, zo)); - } - return res; - }; - } - else - { - masking_fn = null; - } - - - if (masking_fn != null) - { - // Mask for the T output will be base on the output of T - 1. In the - // case T = 0, a zero filled tensor will be used. - var flat_zero_output = new Tensors(); - foreach (var o in nest.flatten(output_time_zero)) - { - flat_zero_output.Add(tf.zeros_like(o)); - } - - - (Tensor, List, Tensors, Tensors) _step(Tensor time, List output_ta_t, Tensors prev_output, Tensors states) - { - /* - RNN step function. - Args: - time: Current timestep value. - output_ta_t: TensorArray. - prev_output: tuple of outputs from time - 1. - *states: List of states. - Returns: - Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` - */ - - var current_input = input_ta.Select(x => x.read(time)).ToList(); - // maybe set shape - // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type - current_input = (List)nest.pack_sequence_as(inputs, current_input); - var mask_t = masking_fn(time); - var (output, new_states) = step_function(current_input, new Tensors { states, constants }); - // mask output - //var flat_output = nest.flatten(output); - var flat_output = output.ToList(); - - var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); - - // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type - var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); - - // mask states - var flat_state = states.ToList(); - var flat_new_state = new_states.ToList(); - - foreach (var (state, new_state) in zip(flat_state, flat_new_state)) - { - if (new_state is Tensor) - { - new_state.set_shape(state.shape); - } - } - - var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); - new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state); - - var ta_index_to_write = return_all_outputs ? time : tf.constant(0); - var Output_ta_t = new List(); - // TODO(Wanglongzhi2001),deal with zip output_ta_t - foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) - { - Output_ta_t.Add(ta.write(ta_index_to_write, Out)); - } - - - - //new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); - - - return (time + 1, Output_ta_t, flat_new_output, new_states); - - } - Func cond = (time) => (time < time_step_t); - - var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states)); - new_states = final_outputs.Item4; - output_ta = final_outputs.Item2; - - } - else - { - (Tensor, List, Tensors) _step(Tensor time, List output_ta_t, Tensors states) - { - var current_input = input_ta.Select(x => x.read(time)).ToList(); - // maybe set shape - // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type - current_input = (List)nest.pack_sequence_as(inputs, current_input); - var (output, new_states) = step_function(current_input, new Tensors { states, constants }); - var flat_state = states.ToList(); - var flat_new_state = new_states.ToList(); - foreach (var (state, new_state) in zip(flat_state, flat_new_state)) - { - if (new_state is Tensor) - { - new_state.set_shape(state.shape); - } - } - var flat_output = output.ToList(); - var ta_index_to_write = return_all_outputs ? time : tf.constant(0); - var Output_ta_t = new List(); - foreach (var (ta, out_) in zip(output_ta_t, flat_output)) - { - Output_ta_t.Add(ta.write(ta_index_to_write, out_)); - } - - new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); - return (time + 1, Output_ta_t, new_states); - } - Func cond = (time) => (time < time_step_t); - var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states)); - new_states = final_outputs.Item3; - output_ta = final_outputs.Item2; - - } - //Tensors outputs = new Tensors(); - foreach (var o in output_ta) - { - outputs.Add(o.stack()); - } - foreach (var o in outputs) - { - last_output.Add(o[-1]); - } - outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs); - last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output); - + //if (mask != null) + //{ + // var mask_list = tf.unstack(mask); + // if (go_backwards) + // { + // mask_list.Reverse(); + // } + + // for (int i = 0; i < time_steps; i++) + // { + // // TODO(Wanglongzhi2001),deal with _get_input_tensor + // var inp = _get_input_tensor(i); + // var mask_t = mask_list[i]; + // // TODO + // var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); + + // var tiled_mask_t = _expand_mask(mask_t, output); + + // Tensors prev_output; + // if (successive_outputs == null) + // { + // prev_output = tf.zeros_like(output); + // } + // else + // { + // prev_output = successive_outputs[successive_outputs.Length - 1]; + // } + + // output = tf.where(tiled_mask_t, output, prev_output); + + // //var flat_states = nest.flatten(states); + // //var flat_new_states = nest.flatten(newStates); + // var flat_states = states.ToList(); + // var flat_new_states = newStates.ToList(); + + // var tiledMaskT = flat_states + // .Select(s => _expand_mask(mask_t, s)) + // .ToArray(); + // var tuple = Tuple.Create(tiledMaskT); + + // List flat_final_states = new List(); + // foreach (var (m, s, ps) in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states)) + // { + // flat_final_states.Add(tf.where(m, s, ps)); + // } + + // states = (Tensors)nest.pack_sequence_as(states, flat_final_states); + // if (return_all_outputs) + // { + // successive_outputs.Add(output); + // successive_states.Add(states); + // } + // else + // { + // successive_outputs = new Tensors { output }; + // successive_states = new Tensors { states }; + // } + + // } + // last_output = successive_outputs[successive_outputs.Length - 1]; + // new_states = successive_states[successive_states.Length - 1]; + // outputs = tf.stack(successive_outputs); + + // if (zero_output_for_mask) + // { + // last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); + // outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); + // } + // else // mask is null + // { + // for (int i = 0; i < time_steps; i++) + // { + // var inp = _get_input_tensor(i); + // var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); + // states = newStates; + + // if (return_all_outputs) + // { + // successive_outputs.Add(output); + // successive_states.Add(newStates); + // } + // else + // { + // successive_outputs = new Tensors { output }; + // successive_states = new Tensors { newStates }; + // } + // } + // last_output = successive_outputs[successive_outputs.Length - 1]; + // new_states = successive_states[successive_states.Length - 1]; + // outputs = tf.stack(successive_outputs); + // } + //} } + //else // unroll == false + //{ + // var states = initial_states; + // // Create input tensor array, if the inputs is nested tensors, then it + // // will be flattened first, and tensor array will be created one per + // // flattened tensor. + // var input_ta = new List(); + // for (int i = 0; i < flatted_inptus.Count; i++) + // { + // input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t)); + // } + + // // Get the time(0) input and compute the output for that, the output will + // // be used to determine the dtype of output tensor array. Don't read from + // // input_ta due to TensorArray clear_after_read default to True. + // var inps = new Tensors(); + // foreach (var inp in flatted_inptus) + // { + // inps.Add(inp[0]); + // } + // var input_time_zero = nest.pack_sequence_as(inputs, inps); + + // // output_time_zero is used to determine the cell output shape and its + // // dtype. the value is discarded. + // (output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants }); + + // var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1); + // var output_ta = new List(); + // for (int i = 0; i < output_time_zero.ToList().Count; i++) + // { + // var Out = output_time_zero.ToList()[i]; + // output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); + // } + + // var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); + + + + // Func? masking_fn; + // Func? compute_masked_output = null; + // if (mask != null) + // { + // if (go_backwards) + // { + // mask = tf.reverse(mask, axis: new[] { 0 }); + // } + // var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t); + // mask_ta = mask_ta.unstack(mask); + + // masking_fn = (time) => + // { + // return mask_ta.read(time); + // }; + + // compute_masked_output = (mask_t, flat_out, flat_mask) => + // { + // var tiled_mask_t = new Tensors(); + // foreach (var o in flat_out) + // { + // tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); + // } + + // Tensors res = new Tensors(); + // foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask)) + // { + // res.Add(tf.where(m, o, fm)); + // } + // return res; + // }; + // } + // // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)? + // else if (input_length is Tensor) + // { + // if (go_backwards) + // { + // var max_len = tf.reduce_max(input_length, axis: 0); + // var rev_input_length = tf.subtract(max_len - 1, input_length); + + // masking_fn = (time) => + // { + // return tf.less(rev_input_length, time); + // }; + // } + // else + // { + // masking_fn = (time) => + // { + // return tf.greater(input_length, time); + // }; + // } + + // compute_masked_output = (mask_t, flat_out, flat_mask) => + // { + // var res = new List(); + // foreach (var (o, zo) in zip(flat_out, flat_mask)) + // { + // res.Add(tf.where(mask_t, o, zo)); + // } + // return res; + // }; + // } + // else + // { + // masking_fn = null; + // } + + + // if (masking_fn != null) + // { + // // Mask for the T output will be base on the output of T - 1. In the + // // case T = 0, a zero filled tensor will be used. + // var flat_zero_output = new Tensors(); + // foreach (var o in nest.flatten(output_time_zero)) + // { + // flat_zero_output.Add(tf.zeros_like(o)); + // } + + + // (Tensor, List, Tensors, Tensors) _step(Tensor time, List output_ta_t, Tensors prev_output, Tensors states) + // { + // /* + // RNN step function. + // Args: + // time: Current timestep value. + // output_ta_t: TensorArray. + // prev_output: tuple of outputs from time - 1. + // *states: List of states. + // Returns: + // Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` + // */ + + // var current_input = input_ta.Select(x => x.read(time)).ToList(); + // // maybe set shape + // // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + // current_input = (List)nest.pack_sequence_as(inputs, current_input); + // var mask_t = masking_fn(time); + // var (output, new_states) = step_function(current_input, new Tensors { states, constants }); + // // mask output + // //var flat_output = nest.flatten(output); + // var flat_output = output.ToList(); + + // var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); + + // // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type + // var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); + + // // mask states + // var flat_state = states.ToList(); + // var flat_new_state = new_states.ToList(); + + // foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + // { + // if (new_state is Tensor) + // { + // new_state.set_shape(state.shape); + // } + // } + + // var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); + // new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state); + + // var ta_index_to_write = return_all_outputs ? time : tf.constant(0); + // var Output_ta_t = new List(); + // // TODO(Wanglongzhi2001),deal with zip output_ta_t + // foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) + // { + // Output_ta_t.Add(ta.write(ta_index_to_write, Out)); + // } + + + + // //new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); + + + // return (time + 1, Output_ta_t, flat_new_output, new_states); + + // } + // Func cond = (time) => (time < time_step_t); + + // var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states)); + // new_states = final_outputs.Item4; + // output_ta = final_outputs.Item2; + + // } + // else + // { + // (Tensor, List, Tensors) _step(Tensor time, List output_ta_t, Tensors states) + // { + // var current_input = input_ta.Select(x => x.read(time)).ToList(); + // // maybe set shape + // // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + // current_input = (List)nest.pack_sequence_as(inputs, current_input); + // var (output, new_states) = step_function(current_input, new Tensors { states, constants }); + // var flat_state = states.ToList(); + // var flat_new_state = new_states.ToList(); + // foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + // { + // if (new_state is Tensor) + // { + // new_state.set_shape(state.shape); + // } + // } + // var flat_output = output.ToList(); + // var ta_index_to_write = return_all_outputs ? time : tf.constant(0); + // var Output_ta_t = new List(); + // foreach (var (ta, out_) in zip(output_ta_t, flat_output)) + // { + // Output_ta_t.Add(ta.write(ta_index_to_write, out_)); + // } + + // new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); + // return (time + 1, Output_ta_t, new_states); + // } + // Func cond = (time) => (time < time_step_t); + // var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states)); + // new_states = final_outputs.Item3; + // output_ta = final_outputs.Item2; + + // } + // //Tensors outputs = new Tensors(); + // foreach (var o in output_ta) + // { + // outputs.Add(o.stack()); + // } + // foreach (var o in outputs) + // { + // last_output.Add(o[-1]); + // } + // outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs); + // last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output); + + //} Func set_shape; set_shape = (output_) => @@ -968,5 +968,27 @@ namespace Tensorflow.Keras return (last_output, Outputs, new_states); } + + // Multiplies 2 tensors (and/or variables) and returns a tensor. + // This operation corresponds to `numpy.dot(a, b, out=None)`. + public Tensor Dot(Tensor x, Tensor y) + { + //if (x.ndim != 1 && (x.ndim > 2 || y.ndim > 2)) + //{ + // var x_shape = new List(); + // foreach (var (i,s) in zip(x.shape.as_int_list(), tf.unstack(tf.shape(x)))) + // { + // if (i != 0) + // { + // x_shape.append(i); + // } + // else + // { + // x_shape.append(s); + // } + // } + //} + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index e768bd0b..660856b6 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -325,7 +325,7 @@ namespace Tensorflow.Keras.Engine nodes_in_decreasing_depth.append(node); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { var tensor_dict = new Dictionary>(); // map input values diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs index c0430458..f40cdec7 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs @@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine if (!built) MaybeBuild(inputs); - var outputs = Call(inputs, state: state, training: training); + var outputs = Call(inputs, initial_state: state, training: training); // memory leak // _set_connectivity_metadata_(inputs, outputs); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 4216c725..7d50f83a 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -254,6 +254,10 @@ namespace Tensorflow.Keras.Engine /// public Func? ReplacedCall { get; set; } = null; + public StateSizeWrapper state_size => throw new NotImplementedException(); + + public int output_size => throw new NotImplementedException(); + public Layer(LayerArgs args) { Initialize(args); @@ -434,56 +438,61 @@ namespace Tensorflow.Keras.Engine public override void SetAttr(string name, object value) { - // TODO(Rinne): deal with "_self_setattr_tracking". + //// TODO(Rinne): deal with "_self_setattr_tracking". - value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); + //value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); - foreach(var val in nest.flatten(value)) - { - if(val is Metric) - { - // TODO(Rinne): deal with metrics. - } - } - - // TODO(Rinne): deal with "_auto_track_sub_layers". - - foreach(var val in nest.flatten(value)) - { - if(val is not IVariableV1 variable) - { - continue; - } - if (variable.Trainable) - { - if (_trainable_weights.Contains(variable)) - { - continue; - } - _trainable_weights.Add(variable); - } - else - { - if (_non_trainable_weights.Contains(variable)) - { - continue; - } - _non_trainable_weights.Add(variable); - } - keras.backend.track_variable(variable); - } + //foreach(var val in nest.flatten(value)) + //{ + // if(val is Metric) + // { + // // TODO(Rinne): deal with metrics. + // } + //} + + //// TODO(Rinne): deal with "_auto_track_sub_layers". + + //foreach(var val in nest.flatten(value)) + //{ + // if(val is not IVariableV1 variable) + // { + // continue; + // } + // if (variable.Trainable) + // { + // if (_trainable_weights.Contains(variable)) + // { + // continue; + // } + // _trainable_weights.Add(variable); + // } + // else + // { + // if (_non_trainable_weights.Contains(variable)) + // { + // continue; + // } + // _non_trainable_weights.Add(variable); + // } + // keras.backend.track_variable(variable); + //} + + //// Directly use the implementation of `Trackable`. + //var t = this.GetType(); + //var field_info = t.GetField(name); + //if (field_info is not null) + //{ + // field_info.SetValue(this, value); + //} + //else + //{ + // CustomizedFields[name] = value; + //} + } - // Directly use the implementation of `Trackable`. - var t = this.GetType(); - var field_info = t.GetField(name); - if (field_info is not null) - { - field_info.SetValue(this, value); - } - else - { - CustomizedFields[name] = value; - } + Tensors ILayer.Call(Tensors inputs, Tensor mask, bool? training, Tensors initial_state, Tensors constants) + { + throw new NotImplementedException(); } } } diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index 27874751..bb70e67e 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -143,7 +143,7 @@ namespace Tensorflow.Keras.Engine } } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (!_has_explicit_input_shape) { @@ -154,10 +154,10 @@ namespace Tensorflow.Keras.Engine { if (!built) _init_graph_network(this.inputs, outputs); - return base.Call(inputs, state, training); + return base.Call(inputs, initial_state, training); } - return base.Call(inputs, state, training); + return base.Call(inputs, initial_state, training); } void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs index bbd49acd..217dd28f 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs @@ -83,7 +83,7 @@ namespace Tensorflow.Keras.Layers _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { var inputs_shape = array_ops.shape(inputs); var batch_size = inputs_shape[0]; diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 3b095bc2..02e9d995 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -709,6 +709,23 @@ namespace Tensorflow.Keras.Layers ReturnState = return_state }); + public ILayer SimpleRNNCell( + int units, + string activation = "tanh", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros") + => new SimpleRNNCell(new SimpleRNNArgs + { + Units = units, + Activation = keras.activations.GetActivationFromName(activation), + UseBias = use_bias, + KernelInitializer = GetInitializerByName(kernel_initializer), + RecurrentInitializer = GetInitializerByName(recurrent_initializer), + } + ); + /// /// Long Short-Term Memory layer - Hochreiter 1997. /// diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs b/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs new file mode 100644 index 00000000..fcf9b596 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs @@ -0,0 +1,80 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.ArgsDefinition.Rnn; +using Tensorflow.Keras.Engine; + + + +namespace Tensorflow.Keras.Layers.Rnn +{ + public class DropoutRNNCellMixin + { + public float dropout; + public float recurrent_dropout; + // Get the dropout mask for RNN cell's input. + public Tensors get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) + { + + return _generate_dropout_mask( + tf.ones_like(input), + dropout, + training, + count); + } + + // Get the recurrent dropout mask for RNN cell. + public Tensors get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) + { + return _generate_dropout_mask( + tf.ones_like(input), + recurrent_dropout, + training, + count); + } + + public Tensors _create_dropout_mask(Tensors input, bool training, int count = 1) + { + return _generate_dropout_mask( + tf.ones_like(input), + dropout, + training, + count); + } + + public Tensors _create_recurrent_dropout_mask(Tensors input, bool training, int count = 1) + { + return _generate_dropout_mask( + tf.ones_like(input), + recurrent_dropout, + training, + count); + } + + public Tensors _generate_dropout_mask(Tensor ones, float rate, bool training, int count = 1) + { + Tensors dropped_inputs() + { + DropoutArgs args = new DropoutArgs(); + args.Rate = rate; + var DropoutLayer = new Dropout(args); + var mask = DropoutLayer.Apply(ones, training: training); + return mask; + } + + if (count > 1) + { + Tensors results = new Tensors(); + for (int i = 0; i < count; i++) + { + results.Add(dropped_inputs()); + } + return results; + } + + return dropped_inputs(); + } + } + + +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index 7bd4047a..a26743e6 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers.Rnn private RNNArgs args; private object input_spec = null; // or NoneValue?? private object state_spec = null; - private object _states = null; + private Tensors _states = null; private object constants_spec = null; private int _num_constants = 0; protected IVariableV1 kernel; @@ -44,19 +44,15 @@ namespace Tensorflow.Keras.Layers.Rnn cell = args.Cell.AsT1; } - - - - Type type = cell.GetType(); - MethodInfo methodInfo = type.GetMethod("Call"); - if (methodInfo == null) + MethodInfo callMethodInfo = type.GetMethod("Call"); + if (callMethodInfo == null) { throw new ValueError(@"Argument `cell` or `cells`should have a `call` method. "); } - PropertyInfo propertyInfo = type.GetProperty("state_size"); - if (propertyInfo == null) + PropertyInfo state_size_info = type.GetProperty("state_size"); + if (state_size_info == null) { throw new ValueError(@"The RNN cell should have a `state_size` attribute"); } @@ -80,7 +76,7 @@ namespace Tensorflow.Keras.Layers.Rnn // States is a tuple consist of cell states_size, like (cell1.state_size, cell2.state_size,...) // state_size can be a single integer, can also be a list/tuple of integers, can also be TensorShape or a list/tuple of TensorShape - public object States + public Tensors States { get { @@ -106,7 +102,6 @@ namespace Tensorflow.Keras.Layers.Rnn // state_size is a array of ints or a positive integer var state_size = cell.state_size; - // TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor Func _get_output_shape; _get_output_shape = (flat_output_size) => @@ -132,8 +127,10 @@ namespace Tensorflow.Keras.Layers.Rnn return output_shape; }; + Type type = cell.GetType(); + PropertyInfo output_size_info = type.GetProperty("output_size"); Shape output_shape; - if (cell.output_size != 0) + if (output_size_info != null) { output_shape = nest.map_structure(_get_output_shape, cell.output_size); // TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型 @@ -160,6 +157,7 @@ namespace Tensorflow.Keras.Layers.Rnn { return output_shape; } + } private Tensors compute_mask(Tensors inputs, Tensors mask) @@ -184,8 +182,6 @@ namespace Tensorflow.Keras.Layers.Rnn { return output_mask; } - - } public override void build(KerasShapesWrapper input_shape) @@ -247,14 +243,18 @@ namespace Tensorflow.Keras.Layers.Rnn protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs); - //bool is_ragged_input = row_length != null; - //_validate_args_if_ragged(is_ragged_input, mask); - var (inputs_processed, initial_state_processed, constants_processed) = _process_inputs(inputs, initial_state, constants); + // 暂时先不接受ragged tensor + int? row_length = null; + bool is_ragged_input = false; + _validate_args_if_ragged(is_ragged_input, mask); + + (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants); _maybe_reset_cell_dropout_mask(cell); if (cell is StackedRNNCells) { - foreach (var cell in ((StackedRNNCells)cell).Cells) + var stack_cell = cell as StackedRNNCells; + foreach (var cell in stack_cell.Cells) { _maybe_reset_cell_dropout_mask(cell); } @@ -263,17 +263,16 @@ namespace Tensorflow.Keras.Layers.Rnn if (mask != null) { // Time step masks must be the same for each input. - //mask = nest.flatten(mask)[0]; - mask = mask[0]; + mask = nest.flatten(mask)[0]; } - Shape input_shape; - if (nest.is_nested(initial_state_processed)) + if (nest.is_nested(inputs)) { // In the case of nested input, use the first element for shape check // input_shape = nest.flatten(inputs)[0].shape; - input_shape = inputs[0].shape; + // TODO(Wanglongzhi2001) + input_shape = nest.flatten(inputs)[0].shape; } else { @@ -322,6 +321,7 @@ namespace Tensorflow.Keras.Layers.Rnn // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states) states = states.Length == 1 ? states[0] : states; var (output, new_states) = cell_call_fn(inputs, null, null, states, constants); + // TODO(Wanglongzhi2001),should cell_call_fn's return value be Tensors, Tensors? if (!nest.is_nested(new_states)) { return (output, new Tensors { new_states }); @@ -351,7 +351,7 @@ namespace Tensorflow.Keras.Layers.Rnn go_backwards: args.GoBackwards, mask: mask, unroll: args.Unroll, - input_length: row_length != null ? row_length : new Tensor(timesteps), + input_length: row_length != null ? new Tensor(row_length) : new Tensor(timesteps), time_major: args.TimeMajor, zero_output_for_mask: args.ZeroOutputForMask, return_all_outputs: args.ReturnSequences); @@ -387,24 +387,9 @@ namespace Tensorflow.Keras.Layers.Rnn } } - private (Tensors, Tensors, Tensors) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants) + private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants) { - bool IsSequence(object obj) - { - // Check if the object is an IEnumerable - if (obj is IEnumerable) - { - // If it is, check if it is a tuple - if (!(obj is Tuple)) - { - return true; - } - } - // If it is not, return false - return false; - } - - if (IsSequence(input)) + if (nest.is_sequence(input)) { if (_num_constants != 0) { @@ -413,6 +398,7 @@ namespace Tensorflow.Keras.Layers.Rnn else { initial_state = inputs[new Slice(1, len(inputs) - _num_constants)]; + constants = inputs[new Slice(len(inputs) - _num_constants, len(inputs))]; } if (len(initial_state) == 0) initial_state = null; @@ -421,32 +407,63 @@ namespace Tensorflow.Keras.Layers.Rnn if (args.Stateful) { - throw new NotImplementedException("argument stateful has not been implemented!"); + if (initial_state != null) + { + var tmp = new Tensor[] { }; + foreach (var s in nest.flatten(States)) + { + tmp.add(tf.math.count_nonzero((Tensor)s)); + } + var non_zero_count = tf.add_n(tmp); + //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state); + if((int)non_zero_count.numpy() > 0) + { + initial_state = States; + } + } + else + { + initial_state = States; + } } + else if(initial_state != null) + { + initial_state = get_initial_state(inputs); + } - return (inputs, initial_state, constants); + if (initial_state.Length != States.Length) + { + throw new ValueError( + $"Layer {this} expects {States.Length} state(s), " + + $"but it received {initial_state.Length} " + + $"initial state(s). Input received: {inputs}"); + } + return (inputs, initial_state, constants); } private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask) { - if (is_ragged_input) + if (!is_ragged_input) { - if (args.Unroll) - { - throw new ValueError("The input received contains RaggedTensors and does " + - "not support unrolling. Disable unrolling by passing " + - "`unroll=False` in the RNN Layer constructor."); - } - if (mask != null) - { - throw new ValueError($"The mask that was passed in was {mask}, which " + - "cannot be applied to RaggedTensor inputs. Please " + - "make sure that there is no mask injected by upstream " + - "layers."); - } + return; + } + + if (args.Unroll) + { + throw new ValueError("The input received contains RaggedTensors and does " + + "not support unrolling. Disable unrolling by passing " + + "`unroll=False` in the RNN Layer constructor."); + } + if (mask != null) + { + throw new ValueError($"The mask that was passed in was {mask}, which " + + "cannot be applied to RaggedTensor inputs. Please " + + "make sure that there is no mask injected by upstream " + + "layers."); } + } void _maybe_reset_cell_dropout_mask(ILayer cell) @@ -489,46 +506,77 @@ namespace Tensorflow.Keras.Layers.Rnn { throw new NotImplementedException(); } - public RNN New(LayerRnnCell cell, - bool return_sequences = false, - bool return_state = false, - bool go_backwards = false, - bool stateful = false, - bool unroll = false, - bool time_major = false) - => new RNN(new RNNArgs - { - Cell = cell, - ReturnSequences = return_sequences, - ReturnState = return_state, - GoBackwards = go_backwards, - Stateful = stateful, - Unroll = unroll, - TimeMajor = time_major - }); - public RNN New(IList cell, - bool return_sequences = false, - bool return_state = false, - bool go_backwards = false, - bool stateful = false, - bool unroll = false, - bool time_major = false) - => new RNN(new RNNArgs - { - Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }), - ReturnSequences = return_sequences, - ReturnState = return_state, - GoBackwards = go_backwards, - Stateful = stateful, - Unroll = unroll, - TimeMajor = time_major - }); + // 好像不能cell不能传接口类型 + //public RNN New(IRnnArgCell cell, + // bool return_sequences = false, + // bool return_state = false, + // bool go_backwards = false, + // bool stateful = false, + // bool unroll = false, + // bool time_major = false) + // => new RNN(new RNNArgs + // { + // Cell = cell, + // ReturnSequences = return_sequences, + // ReturnState = return_state, + // GoBackwards = go_backwards, + // Stateful = stateful, + // Unroll = unroll, + // TimeMajor = time_major + // }); + + //public RNN New(List cell, + // bool return_sequences = false, + // bool return_state = false, + // bool go_backwards = false, + // bool stateful = false, + // bool unroll = false, + // bool time_major = false) + // => new RNN(new RNNArgs + // { + // Cell = cell, + // ReturnSequences = return_sequences, + // ReturnState = return_state, + // GoBackwards = go_backwards, + // Stateful = stateful, + // Unroll = unroll, + // TimeMajor = time_major + // }); + + + protected Tensors get_initial_state(Tensor inputs) + { + Type type = cell.GetType(); + MethodInfo MethodInfo = type.GetMethod("get_initial_state"); + if (nest.is_nested(inputs)) + { + // The input are nested sequences. Use the first element in the seq + // to get batch size and dtype. + inputs = nest.flatten(inputs)[0]; + } - protected Tensor get_initial_state(Tensor inputs) - { - return _generate_zero_filled_state_for_cell(null, null); + var input_shape = tf.shape(inputs); + var batch_size = args.TimeMajor ? input_shape[1] : input_shape[0]; + var dtype = inputs.dtype; + Tensor init_state; + if (MethodInfo != null) + { + init_state = (Tensor)MethodInfo.Invoke(cell, new object[] { null, batch_size, dtype }); + } + else + { + init_state = RNNUtils.generate_zero_filled_state(batch_size, cell.state_size, dtype); + } + + //if (!nest.is_nested(init_state)) + //{ + // init_state = new List { init_state}; + //} + return new List { init_state }; + + //return _generate_zero_filled_state_for_cell(null, null); } Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size) diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNNUtils.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNNUtils.cs new file mode 100644 index 00000000..f516f765 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNNUtils.cs @@ -0,0 +1,59 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Util; +using OneOf; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.Layers.Rnn +{ + public class RNNUtils + { + public static Tensor generate_zero_filled_state(Tensor batch_size_tensor, StateSizeWrapper state_size, TF_DataType dtype = TF_DataType.TF_FLOAT) + { + if (batch_size_tensor == null || dtype == null) + { + throw new ValueError( + "batch_size and dtype cannot be None while constructing initial " + + $"state. Received: batch_size={batch_size_tensor}, dtype={dtype}"); + } + + Func create_zeros; + create_zeros = (StateSizeWrapper unnested_state_size) => + { + var flat_dims = unnested_state_size.state_size; + //if (unnested_state_size is int[]) + //{ + // flat_dims = new Shape(unnested_state_size.AsT0).as_int_list(); + //} + //else if (unnested_state_size.IsT1) + //{ + // flat_dims = new Shape(unnested_state_size.AsT1).as_int_list(); + //} + var init_state_size = batch_size_tensor.ToArray().concat(flat_dims); + return tf.zeros(init_state_size, dtype: dtype); + }; + + //if (nest.is_nested(state_size)) + //{ + // return nest.map_structure(create_zeros, state_size); + //} + //else + //{ + // return create_zeros(state_size); + //} + return create_zeros(state_size); + + } + + public static Tensor generate_zero_filled_state_for_cell(SimpleRNNCell cell, Tensors inputs, Tensor batch_size, TF_DataType dtype) + { + if (inputs != null) + { + batch_size = tf.shape(inputs)[0]; + dtype = inputs.dtype; + } + return generate_zero_filled_state(batch_size, cell.state_size, dtype); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index 86985d7e..2c89d2e6 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -4,6 +4,7 @@ using System.Text; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Util; namespace Tensorflow.Keras.Layers.Rnn { @@ -13,10 +14,23 @@ namespace Tensorflow.Keras.Layers.Rnn IVariableV1 kernel; IVariableV1 recurrent_kernel; IVariableV1 bias; - + DropoutRNNCellMixin DRCMixin; public SimpleRNNCell(SimpleRNNArgs args) : base(args) { this.args = args; + if (args.Units <= 0) + { + throw new ValueError( + $"units must be a positive integer, got {args.Units}"); + } + this.args.Dropout = Math.Min(1f, Math.Max(0f, this.args.Dropout)); + this.args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this.args.RecurrentDropout)); + this.args.state_size = this.args.Units; + this.args.output_size = this.args.Units; + + DRCMixin = new DropoutRNNCellMixin(); + DRCMixin.dropout = this.args.Dropout; + DRCMixin.recurrent_dropout = this.args.RecurrentDropout; } public override void build(KerasShapesWrapper input_shape) @@ -44,7 +58,81 @@ namespace Tensorflow.Keras.Layers.Rnn protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { - return base.Call(inputs, initial_state, training); + Console.WriteLine($"shape of input: {inputs.shape}"); + Tensor states = initial_state[0]; + Console.WriteLine($"shape of initial_state: {states.shape}"); + + var prev_output = nest.is_nested(states) ? states[0] : states; + var dp_mask = DRCMixin.get_dropout_maskcell_for_cell(inputs, training.Value); + var rec_dp_mask = DRCMixin.get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value); + + Tensor h; + var ranks = inputs.rank; + //if (dp_mask != null) + if(false) + { + if (ranks > 2) + { + h = tf.linalg.tensordot(tf.multiply(inputs, dp_mask), kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); + } + else + { + h = math_ops.matmul(tf.multiply(inputs, dp_mask), kernel.AsTensor()); + } + } + else + { + if (ranks > 2) + { + h = tf.linalg.tensordot(inputs, kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); + } + else + { + h = math_ops.matmul(inputs, kernel.AsTensor()); + } + } + + if (bias != null) + { + h = tf.nn.bias_add(h, bias); + } + + if (rec_dp_mask != null) + { + prev_output = tf.multiply(prev_output, rec_dp_mask); + } + + ranks = prev_output.rank; + Console.WriteLine($"shape of h: {h.shape}"); + + Tensor output; + if (ranks > 2) + { + var tmp = tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); + output = h + tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } })[0]; + } + else + { + output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor())[0]; + + } + Console.WriteLine($"shape of output: {output.shape}"); + + if (args.Activation != null) + { + output = args.Activation.Apply(output); + } + if (nest.is_nested(states)) + { + return (output, new Tensors { output }); + } + return (output, output); + } + + + public Tensor get_initial_state(Tensors inputs, Tensor batch_size, TF_DataType dtype) + { + return RNNUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); } } } diff --git a/src/TensorflowNET.Hub/KerasLayer.cs b/src/TensorflowNET.Hub/KerasLayer.cs index b9ca949b..6a2ecb4c 100644 --- a/src/TensorflowNET.Hub/KerasLayer.cs +++ b/src/TensorflowNET.Hub/KerasLayer.cs @@ -89,7 +89,7 @@ namespace Tensorflow.Hub } } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { _check_trainability(); diff --git a/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs b/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs deleted file mode 100644 index ac5ba15e..00000000 --- a/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs +++ /dev/null @@ -1,60 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System.Collections.Generic; -using Tensorflow.Keras.Callbacks; -using Tensorflow.Keras.Engine; -using static Tensorflow.KerasApi; - - -namespace Tensorflow.Keras.UnitTest.Callbacks -{ - [TestClass] - public class EarlystoppingTest - { - [TestMethod] - // Because loading the weight variable into the model has not yet been implemented, - // so you'd better not set patience too large, because the weights will equal to the last epoch's weights. - public void Earlystopping() - { - var layers = keras.layers; - var model = keras.Sequential(new List - { - layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)), - layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu), - layers.MaxPooling2D(), - layers.Flatten(), - layers.Dense(128, activation: keras.activations.Relu), - layers.Dense(10) - }); - - - model.summary(); - - model.compile(optimizer: keras.optimizers.RMSprop(1e-3f), - loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true), - metrics: new[] { "acc" }); - - var num_epochs = 3; - var batch_size = 8; - - var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data(); - x_train = x_train / 255.0f; - // define a CallbackParams first, the parameters you pass al least contain Model and Epochs. - CallbackParams callback_parameters = new CallbackParams - { - Model = model, - Epochs = num_epochs, - }; - // define your earlystop - ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy"); - // define a callbcaklist, then add the earlystopping to it. - var callbacks = new List(); - callbacks.add(earlystop); - - model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks); - } - - } - - -} - diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index 3de33746..1e2f894b 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -144,6 +144,17 @@ namespace Tensorflow.Keras.UnitTest.Layers Assert.AreEqual(expected_output, actual_output); } + [TestMethod] + public void SimpleRNNCell() + { + var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; + var x = tf.random.normal(new Shape(4, 100)); + var cell = keras.layers.SimpleRNNCell(64); + var (y, h1) = cell.Apply(inputs:x, state:h0); + Assert.AreEqual((4, 64), y.shape); + Assert.AreEqual((4, 64), h1[0].shape); + } + [TestMethod, Ignore("WIP")] public void SimpleRNN() { diff --git a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj index d744c336..db7d5892 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj +++ b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj @@ -67,4 +67,8 @@ + + + +