Add feature(not completed):add SimpleRNNCell, StackedRNNCell, RNN and test.tags/v0.110.0-LSTM-Model
| @@ -12,9 +12,14 @@ namespace Tensorflow.Common.Types | |||||
| /// create a single-dim generalized Tensor shape. | /// create a single-dim generalized Tensor shape. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="dim"></param> | /// <param name="dim"></param> | ||||
| public GeneralizedTensorShape(int dim) | |||||
| public GeneralizedTensorShape(int dim, int size = 1) | |||||
| { | { | ||||
| Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } }; | |||||
| var elem = new TensorShapeConfig() { Items = new long?[] { dim } }; | |||||
| Shapes = Enumerable.Repeat(elem, size).ToArray(); | |||||
| //Shapes = new TensorShapeConfig[size]; | |||||
| //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } }); | |||||
| //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } }); | |||||
| ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } }; | |||||
| } | } | ||||
| public GeneralizedTensorShape(Shape shape) | public GeneralizedTensorShape(Shape shape) | ||||
| @@ -113,6 +118,11 @@ namespace Tensorflow.Common.Types | |||||
| return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s))); | return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s))); | ||||
| } | } | ||||
| } | } | ||||
| public static implicit operator GeneralizedTensorShape(int dims) | |||||
| => new GeneralizedTensorShape(dims); | |||||
| public IEnumerator<long?[]> GetEnumerator() | public IEnumerator<long?[]> GetEnumerator() | ||||
| { | { | ||||
| @@ -10,6 +10,9 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
| [JsonProperty("cell")] | [JsonProperty("cell")] | ||||
| // TODO: the cell should be serialized with `serialize_keras_object`. | // TODO: the cell should be serialized with `serialize_keras_object`. | ||||
| public IRnnCell Cell { get; set; } = null; | public IRnnCell Cell { get; set; } = null; | ||||
| [JsonProperty("cells")] | |||||
| public IList<IRnnCell> Cells { get; set; } = null; | |||||
| [JsonProperty("return_sequences")] | [JsonProperty("return_sequences")] | ||||
| public bool ReturnSequences { get; set; } = false; | public bool ReturnSequences { get; set; } = false; | ||||
| [JsonProperty("return_state")] | [JsonProperty("return_state")] | ||||
| @@ -1,10 +1,11 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Keras.Layers.Rnn; | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | namespace Tensorflow.Keras.ArgsDefinition.Rnn | ||||
| { | { | ||||
| public class StackedRNNCellsArgs : LayerArgs | public class StackedRNNCellsArgs : LayerArgs | ||||
| { | { | ||||
| public IList<RnnCell> Cells { get; set; } | |||||
| public IList<IRnnCell> Cells { get; set; } | |||||
| public Dictionary<string, object> Kwargs { get; set; } = null; | public Dictionary<string, object> Kwargs { get; set; } = null; | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
| using Tensorflow.Keras.Layers.Rnn; | |||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | ||||
| @@ -192,6 +193,19 @@ namespace Tensorflow.Keras.Layers | |||||
| float offset = 0, | float offset = 0, | ||||
| Shape input_shape = null); | Shape input_shape = null); | ||||
| public IRnnCell SimpleRNNCell( | |||||
| int units, | |||||
| string activation = "tanh", | |||||
| bool use_bias = true, | |||||
| string kernel_initializer = "glorot_uniform", | |||||
| string recurrent_initializer = "orthogonal", | |||||
| string bias_initializer = "zeros", | |||||
| float dropout = 0f, | |||||
| float recurrent_dropout = 0f); | |||||
| public IRnnCell StackedRNNCells( | |||||
| IEnumerable<IRnnCell> cells); | |||||
| public ILayer SimpleRNN(int units, | public ILayer SimpleRNN(int units, | ||||
| string activation = "tanh", | string activation = "tanh", | ||||
| string kernel_initializer = "glorot_uniform", | string kernel_initializer = "glorot_uniform", | ||||
| @@ -200,6 +214,26 @@ namespace Tensorflow.Keras.Layers | |||||
| bool return_sequences = false, | bool return_sequences = false, | ||||
| bool return_state = false); | bool return_state = false); | ||||
| public ILayer RNN( | |||||
| IRnnCell cell, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool unroll = false, | |||||
| bool time_major = false | |||||
| ); | |||||
| public ILayer RNN( | |||||
| IEnumerable<IRnnCell> cell, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool unroll = false, | |||||
| bool time_major = false | |||||
| ); | |||||
| public ILayer Subtract(); | public ILayer Subtract(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -109,7 +109,19 @@ namespace Tensorflow.Operations | |||||
| return ta; | return ta; | ||||
| });*/ | });*/ | ||||
| throw new NotImplementedException(""); | |||||
| //if (indices is EagerTensor) | |||||
| //{ | |||||
| // indices = indices as EagerTensor; | |||||
| // indices = indices.numpy(); | |||||
| //} | |||||
| //foreach (var (index, val) in zip(indices.ToArray<int>(), array_ops.unstack(value))) | |||||
| //{ | |||||
| // this.write(index, val); | |||||
| //} | |||||
| //return base; | |||||
| //throw new NotImplementedException(""); | |||||
| return this; | |||||
| } | } | ||||
| public void _merge_element_shape(Shape shape) | public void _merge_element_shape(Shape shape) | ||||
| @@ -17,6 +17,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Eager; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
| @@ -146,7 +147,9 @@ namespace Tensorflow.Operations | |||||
| return ta; | return ta; | ||||
| });*/ | });*/ | ||||
| throw new NotImplementedException(""); | |||||
| //throw new NotImplementedException(""); | |||||
| return this; | |||||
| } | } | ||||
| public void _merge_element_shape(Shape shape) | public void _merge_element_shape(Shape shape) | ||||
| @@ -510,7 +510,7 @@ namespace Tensorflow.Keras | |||||
| } | } | ||||
| } | } | ||||
| // tf.where needs its condition tensor to be the same shape as its two | // tf.where needs its condition tensor to be the same shape as its two | ||||
| // result tensors, but in our case the condition (mask) tensor is | // result tensors, but in our case the condition (mask) tensor is | ||||
| // (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. | // (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. | ||||
| @@ -535,7 +535,7 @@ namespace Tensorflow.Keras | |||||
| { | { | ||||
| mask_t = tf.expand_dims(mask_t, -1); | mask_t = tf.expand_dims(mask_t, -1); | ||||
| } | } | ||||
| var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank)); | |||||
| var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().Skip(fixed_dim).ToArray()); | |||||
| return tf.tile(mask_t, multiples); | return tf.tile(mask_t, multiples); | ||||
| } | } | ||||
| @@ -570,9 +570,6 @@ namespace Tensorflow.Keras | |||||
| // individually. The result of this will be a tuple of lists, each of | // individually. The result of this will be a tuple of lists, each of | ||||
| // the item in tuple is list of the tensor with shape (batch, feature) | // the item in tuple is list of the tensor with shape (batch, feature) | ||||
| Tensors _process_single_input_t(Tensor input_t) | Tensors _process_single_input_t(Tensor input_t) | ||||
| { | { | ||||
| var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim | var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim | ||||
| @@ -609,7 +606,7 @@ namespace Tensorflow.Keras | |||||
| var mask_list = tf.unstack(mask); | var mask_list = tf.unstack(mask); | ||||
| if (go_backwards) | if (go_backwards) | ||||
| { | { | ||||
| mask_list.Reverse(); | |||||
| mask_list.Reverse().ToArray(); | |||||
| } | } | ||||
| for (int i = 0; i < time_steps; i++) | for (int i = 0; i < time_steps; i++) | ||||
| @@ -629,9 +626,10 @@ namespace Tensorflow.Keras | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| prev_output = successive_outputs[successive_outputs.Length - 1]; | |||||
| prev_output = successive_outputs.Last(); | |||||
| } | } | ||||
| // output could be a tensor | |||||
| output = tf.where(tiled_mask_t, output, prev_output); | output = tf.where(tiled_mask_t, output, prev_output); | ||||
| var flat_states = Nest.Flatten(states).ToList(); | var flat_states = Nest.Flatten(states).ToList(); | ||||
| @@ -661,13 +659,13 @@ namespace Tensorflow.Keras | |||||
| } | } | ||||
| } | } | ||||
| last_output = successive_outputs[successive_outputs.Length - 1]; | |||||
| new_states = successive_states[successive_states.Length - 1]; | |||||
| last_output = successive_outputs.Last(); | |||||
| new_states = successive_states.Last(); | |||||
| outputs = tf.stack(successive_outputs); | outputs = tf.stack(successive_outputs); | ||||
| if (zero_output_for_mask) | if (zero_output_for_mask) | ||||
| { | { | ||||
| last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); | |||||
| last_output = tf.where(_expand_mask(mask_list.Last(), last_output), last_output, tf.zeros_like(last_output)); | |||||
| outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); | outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); | ||||
| } | } | ||||
| else // mask is null | else // mask is null | ||||
| @@ -689,8 +687,8 @@ namespace Tensorflow.Keras | |||||
| successive_states = new Tensors { newStates }; | successive_states = new Tensors { newStates }; | ||||
| } | } | ||||
| } | } | ||||
| last_output = successive_outputs[successive_outputs.Length - 1]; | |||||
| new_states = successive_states[successive_states.Length - 1]; | |||||
| last_output = successive_outputs.Last(); | |||||
| new_states = successive_states.Last(); | |||||
| outputs = tf.stack(successive_outputs); | outputs = tf.stack(successive_outputs); | ||||
| } | } | ||||
| } | } | ||||
| @@ -701,6 +699,8 @@ namespace Tensorflow.Keras | |||||
| // Create input tensor array, if the inputs is nested tensors, then it | // Create input tensor array, if the inputs is nested tensors, then it | ||||
| // will be flattened first, and tensor array will be created one per | // will be flattened first, and tensor array will be created one per | ||||
| // flattened tensor. | // flattened tensor. | ||||
| var input_ta = new List<TensorArray>(); | var input_ta = new List<TensorArray>(); | ||||
| for (int i = 0; i < flatted_inptus.Count; i++) | for (int i = 0; i < flatted_inptus.Count; i++) | ||||
| { | { | ||||
| @@ -719,6 +719,7 @@ namespace Tensorflow.Keras | |||||
| } | } | ||||
| } | } | ||||
| // Get the time(0) input and compute the output for that, the output will | // Get the time(0) input and compute the output for that, the output will | ||||
| // be used to determine the dtype of output tensor array. Don't read from | // be used to determine the dtype of output tensor array. Don't read from | ||||
| // input_ta due to TensorArray clear_after_read default to True. | // input_ta due to TensorArray clear_after_read default to True. | ||||
| @@ -773,7 +774,7 @@ namespace Tensorflow.Keras | |||||
| return res; | return res; | ||||
| }; | }; | ||||
| } | } | ||||
| // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)? | |||||
| // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor), it could be an integer or tensor | |||||
| else if (input_length is Tensor) | else if (input_length is Tensor) | ||||
| { | { | ||||
| if (go_backwards) | if (go_backwards) | ||||
| @@ -685,6 +685,34 @@ namespace Tensorflow.Keras.Layers | |||||
| Alpha = alpha | Alpha = alpha | ||||
| }); | }); | ||||
| public IRnnCell SimpleRNNCell( | |||||
| int units, | |||||
| string activation = "tanh", | |||||
| bool use_bias = true, | |||||
| string kernel_initializer = "glorot_uniform", | |||||
| string recurrent_initializer = "orthogonal", | |||||
| string bias_initializer = "zeros", | |||||
| float dropout = 0f, | |||||
| float recurrent_dropout = 0f) | |||||
| => new SimpleRNNCell(new SimpleRNNCellArgs | |||||
| { | |||||
| Units = units, | |||||
| Activation = keras.activations.GetActivationFromName(activation), | |||||
| UseBias = use_bias, | |||||
| KernelInitializer = GetInitializerByName(kernel_initializer), | |||||
| RecurrentInitializer = GetInitializerByName(recurrent_initializer), | |||||
| Dropout = dropout, | |||||
| RecurrentDropout = recurrent_dropout | |||||
| }); | |||||
| public IRnnCell StackedRNNCells( | |||||
| IEnumerable<IRnnCell> cells) | |||||
| => new StackedRNNCells(new StackedRNNCellsArgs | |||||
| { | |||||
| Cells = cells.ToList() | |||||
| }); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| @@ -709,6 +737,55 @@ namespace Tensorflow.Keras.Layers | |||||
| ReturnState = return_state | ReturnState = return_state | ||||
| }); | }); | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="cell"></param> | |||||
| /// <param name="return_sequences"></param> | |||||
| /// <param name="return_state"></param> | |||||
| /// <param name="go_backwards"></param> | |||||
| /// <param name="stateful"></param> | |||||
| /// <param name="unroll"></param> | |||||
| /// <param name="time_major"></param> | |||||
| /// <returns></returns> | |||||
| public ILayer RNN( | |||||
| IRnnCell cell, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool unroll = false, | |||||
| bool time_major = false) | |||||
| => new RNN(new RNNArgs | |||||
| { | |||||
| Cell = cell, | |||||
| ReturnSequences = return_sequences, | |||||
| ReturnState = return_state, | |||||
| GoBackwards = go_backwards, | |||||
| Stateful = stateful, | |||||
| Unroll = unroll, | |||||
| TimeMajor = time_major | |||||
| }); | |||||
| public ILayer RNN( | |||||
| IEnumerable<IRnnCell> cell, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool unroll = false, | |||||
| bool time_major = false) | |||||
| => new RNN(new RNNArgs | |||||
| { | |||||
| Cells = cell.ToList(), | |||||
| ReturnSequences = return_sequences, | |||||
| ReturnState = return_state, | |||||
| GoBackwards = go_backwards, | |||||
| Stateful = stateful, | |||||
| Unroll = unroll, | |||||
| TimeMajor = time_major | |||||
| }); | |||||
| /// <summary> | /// <summary> | ||||
| /// Long Short-Term Memory layer - Hochreiter 1997. | /// Long Short-Term Memory layer - Hochreiter 1997. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -17,6 +17,21 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| } | } | ||||
| protected void _create_non_trackable_mask_cache() | |||||
| { | |||||
| } | |||||
| public void reset_dropout_mask() | |||||
| { | |||||
| } | |||||
| public void reset_recurrent_dropout_mask() | |||||
| { | |||||
| } | |||||
| public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) | public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) | ||||
| { | { | ||||
| if (dropout == 0f) | if (dropout == 0f) | ||||
| @@ -38,7 +38,17 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| SupportsMasking = true; | SupportsMasking = true; | ||||
| // if is StackedRnncell | // if is StackedRnncell | ||||
| _cell = args.Cell; | |||||
| if (args.Cells != null) | |||||
| { | |||||
| _cell = new StackedRNNCells(new StackedRNNCellsArgs | |||||
| { | |||||
| Cells = args.Cells | |||||
| }); | |||||
| } | |||||
| else | |||||
| { | |||||
| _cell = args.Cell; | |||||
| } | |||||
| // get input_shape | // get input_shape | ||||
| _args = PreConstruct(args); | _args = PreConstruct(args); | ||||
| @@ -122,6 +132,8 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list()); | var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list()); | ||||
| return new Shape(state_shape); | return new Shape(state_shape); | ||||
| }; | }; | ||||
| var state_shape = _get_state_shape(state_size); | var state_shape = _get_state_shape(state_size); | ||||
| return new List<Shape> { output_shape, state_shape }; | return new List<Shape> { output_shape, state_shape }; | ||||
| @@ -240,7 +252,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| if (_cell is StackedRNNCells) | if (_cell is StackedRNNCells) | ||||
| { | { | ||||
| var stack_cell = _cell as StackedRNNCells; | var stack_cell = _cell as StackedRNNCells; | ||||
| foreach (var cell in stack_cell.Cells) | |||||
| foreach (IRnnCell cell in stack_cell.Cells) | |||||
| { | { | ||||
| _maybe_reset_cell_dropout_mask(cell); | _maybe_reset_cell_dropout_mask(cell); | ||||
| } | } | ||||
| @@ -253,7 +265,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| } | } | ||||
| Shape input_shape; | Shape input_shape; | ||||
| if (!inputs.IsSingle()) | |||||
| if (!inputs.IsNested()) | |||||
| { | { | ||||
| // In the case of nested input, use the first element for shape check | // In the case of nested input, use the first element for shape check | ||||
| // input_shape = nest.flatten(inputs)[0].shape; | // input_shape = nest.flatten(inputs)[0].shape; | ||||
| @@ -267,7 +279,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1]; | var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1]; | ||||
| if (_args.Unroll && timesteps != null) | |||||
| if (_args.Unroll && timesteps == null) | |||||
| { | { | ||||
| throw new ValueError( | throw new ValueError( | ||||
| "Cannot unroll a RNN if the " + | "Cannot unroll a RNN if the " + | ||||
| @@ -302,7 +314,6 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| states = new Tensors(states.SkipLast(_num_constants)); | states = new Tensors(states.SkipLast(_num_constants)); | ||||
| states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; | states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; | ||||
| var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); | var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); | ||||
| // TODO(Wanglongzhi2001),should cell_call_fn's return value be Tensors, Tensors? | |||||
| return (output, new_states.Single); | return (output, new_states.Single); | ||||
| }; | }; | ||||
| } | } | ||||
| @@ -310,13 +321,14 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| { | { | ||||
| step = (inputs, states) => | step = (inputs, states) => | ||||
| { | { | ||||
| states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; | |||||
| states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states; | |||||
| var (output, new_states) = _cell.Apply(inputs, states); | var (output, new_states) = _cell.Apply(inputs, states); | ||||
| return (output, new_states.Single); | |||||
| return (output, new_states); | |||||
| }; | }; | ||||
| } | } | ||||
| var (last_output, outputs, states) = keras.backend.rnn(step, | |||||
| var (last_output, outputs, states) = keras.backend.rnn( | |||||
| step, | |||||
| inputs, | inputs, | ||||
| initial_state, | initial_state, | ||||
| constants: constants, | constants: constants, | ||||
| @@ -394,6 +406,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| initial_state = null; | initial_state = null; | ||||
| inputs = inputs[0]; | inputs = inputs[0]; | ||||
| } | } | ||||
| if (_args.Stateful) | if (_args.Stateful) | ||||
| { | { | ||||
| @@ -402,7 +415,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| var tmp = new Tensor[] { }; | var tmp = new Tensor[] { }; | ||||
| foreach (var s in nest.flatten(States)) | foreach (var s in nest.flatten(States)) | ||||
| { | { | ||||
| tmp.add(tf.math.count_nonzero((Tensor)s)); | |||||
| tmp.add(tf.math.count_nonzero(s.Single())); | |||||
| } | } | ||||
| var non_zero_count = tf.add_n(tmp); | var non_zero_count = tf.add_n(tmp); | ||||
| //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state); | //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state); | ||||
| @@ -415,6 +428,15 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| { | { | ||||
| initial_state = States; | initial_state = States; | ||||
| } | } | ||||
| // TODO(Wanglongzhi2001), | |||||
| // initial_state = tf.nest.map_structure( | |||||
| //# When the layer has a inferred dtype, use the dtype from the | |||||
| //# cell. | |||||
| // lambda v: tf.cast( | |||||
| // v, self.compute_dtype or self.cell.compute_dtype | |||||
| // ), | |||||
| // initial_state, | |||||
| // ) | |||||
| } | } | ||||
| else if (initial_state is null) | else if (initial_state is null) | ||||
| @@ -424,10 +446,9 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| if (initial_state.Length != States.Length) | if (initial_state.Length != States.Length) | ||||
| { | { | ||||
| throw new ValueError( | |||||
| $"Layer {this} expects {States.Length} state(s), " + | |||||
| $"but it received {initial_state.Length} " + | |||||
| $"initial state(s). Input received: {inputs}"); | |||||
| throw new ValueError($"Layer {this} expects {States.Length} state(s), " + | |||||
| $"but it received {initial_state.Length} " + | |||||
| $"initial state(s). Input received: {inputs}"); | |||||
| } | } | ||||
| return (inputs, initial_state, constants); | return (inputs, initial_state, constants); | ||||
| @@ -458,11 +479,11 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| void _maybe_reset_cell_dropout_mask(ILayer cell) | void _maybe_reset_cell_dropout_mask(ILayer cell) | ||||
| { | { | ||||
| //if (cell is DropoutRNNCellMixin) | |||||
| //{ | |||||
| // cell.reset_dropout_mask(); | |||||
| // cell.reset_recurrent_dropout_mask(); | |||||
| //} | |||||
| if (cell is DropoutRNNCellMixin CellDRCMixin) | |||||
| { | |||||
| CellDRCMixin.reset_dropout_mask(); | |||||
| CellDRCMixin.reset_recurrent_dropout_mask(); | |||||
| } | |||||
| } | } | ||||
| private static RNNArgs PreConstruct(RNNArgs args) | private static RNNArgs PreConstruct(RNNArgs args) | ||||
| @@ -537,15 +558,24 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| protected Tensors get_initial_state(Tensors inputs) | protected Tensors get_initial_state(Tensors inputs) | ||||
| { | { | ||||
| var get_initial_state_fn = _cell.GetType().GetMethod("get_initial_state"); | |||||
| var input = inputs[0]; | var input = inputs[0]; | ||||
| var input_shape = input.shape; | |||||
| var input_shape = inputs.shape; | |||||
| var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0]; | var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0]; | ||||
| var dtype = input.dtype; | var dtype = input.dtype; | ||||
| Tensors init_state; | |||||
| if (_cell is RnnCellBase rnn_base_cell) | |||||
| Tensors init_state = new Tensors(); | |||||
| if(get_initial_state_fn != null) | |||||
| { | { | ||||
| init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype); | |||||
| init_state = (Tensors)get_initial_state_fn.Invoke(_cell, new object[] { inputs, batch_size, dtype }); | |||||
| } | } | ||||
| //if (_cell is RnnCellBase rnn_base_cell) | |||||
| //{ | |||||
| // init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype); | |||||
| //} | |||||
| else | else | ||||
| { | { | ||||
| init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype); | init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype); | ||||
| @@ -6,6 +6,7 @@ using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
| using Tensorflow.Common.Extensions; | using Tensorflow.Common.Extensions; | ||||
| using Tensorflow.Keras.Utils; | |||||
| namespace Tensorflow.Keras.Layers.Rnn | namespace Tensorflow.Keras.Layers.Rnn | ||||
| { | { | ||||
| @@ -77,8 +78,10 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value); | var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value); | ||||
| Tensor h; | Tensor h; | ||||
| var ranks = inputs.rank; | |||||
| if (dp_mask != null) | if (dp_mask != null) | ||||
| { | { | ||||
| h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor()); | h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor()); | ||||
| } | } | ||||
| else | else | ||||
| @@ -95,7 +98,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| { | { | ||||
| prev_output = math_ops.multiply(prev_output, rec_dp_mask); | prev_output = math_ops.multiply(prev_output, rec_dp_mask); | ||||
| } | } | ||||
| var tmp = _recurrent_kernel.AsTensor(); | |||||
| Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor()); | Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor()); | ||||
| if (_args.Activation != null) | if (_args.Activation != null) | ||||
| @@ -113,5 +116,10 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| return new Tensors(output, output); | return new Tensors(output, output); | ||||
| } | } | ||||
| } | } | ||||
| public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null) | |||||
| { | |||||
| return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,17 +1,20 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.ComponentModel; | using System.ComponentModel; | ||||
| using System.Linq; | |||||
| using Tensorflow.Common.Extensions; | |||||
| using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| using Tensorflow.Keras.Utils; | |||||
| namespace Tensorflow.Keras.Layers.Rnn | namespace Tensorflow.Keras.Layers.Rnn | ||||
| { | { | ||||
| public class StackedRNNCells : Layer, IRnnCell | public class StackedRNNCells : Layer, IRnnCell | ||||
| { | { | ||||
| public IList<RnnCell> Cells { get; set; } | |||||
| public IList<IRnnCell> Cells { get; set; } | |||||
| public bool reverse_state_order; | public bool reverse_state_order; | ||||
| public StackedRNNCells(StackedRNNCellsArgs args) : base(args) | public StackedRNNCells(StackedRNNCellsArgs args) : base(args) | ||||
| @@ -20,8 +23,19 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| { | { | ||||
| args.Kwargs = new Dictionary<string, object>(); | args.Kwargs = new Dictionary<string, object>(); | ||||
| } | } | ||||
| foreach (var cell in args.Cells) | |||||
| { | |||||
| //Type type = cell.GetType(); | |||||
| //var CallMethodInfo = type.GetMethod("Call"); | |||||
| //if (CallMethodInfo == null) | |||||
| //{ | |||||
| // throw new ValueError( | |||||
| // "All cells must have a `Call` method. " + | |||||
| // $"Received cell without a `Call` method: {cell}"); | |||||
| //} | |||||
| } | |||||
| Cells = args.Cells; | Cells = args.Cells; | ||||
| reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false); | reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false); | ||||
| if (reverse_state_order) | if (reverse_state_order) | ||||
| @@ -33,91 +47,112 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| } | } | ||||
| } | } | ||||
| public object state_size | |||||
| public GeneralizedTensorShape StateSize | |||||
| { | { | ||||
| get => throw new NotImplementedException(); | |||||
| //@property | |||||
| //def state_size(self) : | |||||
| // return tuple(c.state_size for c in | |||||
| // (self.cells[::- 1] if self.reverse_state_order else self.cells)) | |||||
| get | |||||
| { | |||||
| GeneralizedTensorShape state_size = new GeneralizedTensorShape(1, Cells.Count); | |||||
| if (reverse_state_order && Cells.Count > 0) | |||||
| { | |||||
| var idxAndCell = Cells.Reverse().Select((cell, idx) => (idx, cell)); | |||||
| foreach (var cell in idxAndCell) | |||||
| { | |||||
| state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First(); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| //foreach (var cell in Cells) | |||||
| //{ | |||||
| // state_size.Shapes.add(cell.StateSize.Shapes.First()); | |||||
| //} | |||||
| var idxAndCell = Cells.Select((cell, idx) => (idx, cell)); | |||||
| foreach (var cell in idxAndCell) | |||||
| { | |||||
| state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First(); | |||||
| } | |||||
| } | |||||
| return state_size; | |||||
| } | |||||
| } | } | ||||
| public object output_size | public object output_size | ||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| var lastCell = Cells[Cells.Count - 1]; | |||||
| if (lastCell.output_size != -1) | |||||
| var lastCell = Cells.LastOrDefault(); | |||||
| if (lastCell.OutputSize.ToSingleShape() != -1) | |||||
| { | { | ||||
| return lastCell.output_size; | |||||
| return lastCell.OutputSize; | |||||
| } | } | ||||
| else if (RNN.is_multiple_state(lastCell.StateSize)) | else if (RNN.is_multiple_state(lastCell.StateSize)) | ||||
| { | { | ||||
| // return ((dynamic)Cells[-1].state_size)[0]; | |||||
| throw new NotImplementedException(""); | |||||
| return lastCell.StateSize.First(); | |||||
| //throw new NotImplementedException(""); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| return Cells[-1].state_size; | |||||
| return lastCell.StateSize; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| public object get_initial_state() | |||||
| public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null) | |||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| // def get_initial_state(self, inputs= None, batch_size= None, dtype= None) : | |||||
| // initial_states = [] | |||||
| // for cell in self.cells[::- 1] if self.reverse_state_order else self.cells: | |||||
| // get_initial_state_fn = getattr(cell, 'get_initial_state', None) | |||||
| // if get_initial_state_fn: | |||||
| // initial_states.append(get_initial_state_fn( | |||||
| // inputs=inputs, batch_size=batch_size, dtype=dtype)) | |||||
| // else: | |||||
| // initial_states.append(_generate_zero_filled_state_for_cell( | |||||
| // cell, inputs, batch_size, dtype)) | |||||
| // return tuple(initial_states) | |||||
| var cells = reverse_state_order ? Cells.Reverse() : Cells; | |||||
| Tensors initial_states = new Tensors(); | |||||
| foreach (var cell in cells) | |||||
| { | |||||
| var get_initial_state_fn = cell.GetType().GetMethod("get_initial_state"); | |||||
| if (get_initial_state_fn != null) | |||||
| { | |||||
| var result = (Tensors)get_initial_state_fn.Invoke(cell, new object[] { inputs, batch_size, dtype }); | |||||
| initial_states.Add(result); | |||||
| } | |||||
| else | |||||
| { | |||||
| initial_states.Add(RnnUtils.generate_zero_filled_state_for_cell(cell, inputs, batch_size.Value, dtype.Value)); | |||||
| } | |||||
| } | |||||
| return initial_states; | |||||
| } | } | ||||
| public object call() | |||||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| // def call(self, inputs, states, constants= None, training= None, ** kwargs): | |||||
| // # Recover per-cell states. | |||||
| // state_size = (self.state_size[::- 1] | |||||
| // if self.reverse_state_order else self.state_size) | |||||
| // nested_states = nest.pack_sequence_as(state_size, nest.flatten(states)) | |||||
| // # Call the cells in order and store the returned states. | |||||
| // new_nested_states = [] | |||||
| // for cell, states in zip(self.cells, nested_states) : | |||||
| // states = states if nest.is_nested(states) else [states] | |||||
| //# TF cell does not wrap the state into list when there is only one state. | |||||
| // is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None | |||||
| // states = states[0] if len(states) == 1 and is_tf_rnn_cell else states | |||||
| // if generic_utils.has_arg(cell.call, 'training'): | |||||
| // kwargs['training'] = training | |||||
| // else: | |||||
| // kwargs.pop('training', None) | |||||
| // # Use the __call__ function for callable objects, eg layers, so that it | |||||
| // # will have the proper name scopes for the ops, etc. | |||||
| // cell_call_fn = cell.__call__ if callable(cell) else cell.call | |||||
| // if generic_utils.has_arg(cell.call, 'constants'): | |||||
| // inputs, states = cell_call_fn(inputs, states, | |||||
| // constants= constants, ** kwargs) | |||||
| // else: | |||||
| // inputs, states = cell_call_fn(inputs, states, ** kwargs) | |||||
| // new_nested_states.append(states) | |||||
| // Recover per-cell states. | |||||
| var state_size = reverse_state_order ? StateSize.Reverse() : StateSize; | |||||
| var nested_states = reverse_state_order ? state.Flatten().Reverse() : state.Flatten(); | |||||
| // return inputs, nest.pack_sequence_as(state_size, | |||||
| // nest.flatten(new_nested_states)) | |||||
| var new_nest_states = new Tensors(); | |||||
| // Call the cells in order and store the returned states. | |||||
| foreach (var (cell, states) in zip(Cells, nested_states)) | |||||
| { | |||||
| // states = states if tf.nest.is_nested(states) else [states] | |||||
| var type = cell.GetType(); | |||||
| bool IsTFRnnCell = type.GetProperty("IsTFRnnCell") != null; | |||||
| state = len(state) == 1 && IsTFRnnCell ? state.FirstOrDefault() : state; | |||||
| RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; | |||||
| Tensors? constants = rnn_optional_args?.Constants; | |||||
| Tensors new_states; | |||||
| (inputs, new_states) = cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); | |||||
| new_nest_states.Add(new_states); | |||||
| } | |||||
| new_nest_states = reverse_state_order ? new_nest_states.Reverse().ToArray() : new_nest_states.ToArray(); | |||||
| return new Nest<Tensor>(new List<Nest<Tensor>> { | |||||
| new Nest<Tensor>(new List<Nest<Tensor>> { new Nest<Tensor>(inputs.Single()) }), new Nest<Tensor>(new_nest_states) }) | |||||
| .ToTensors(); | |||||
| } | } | ||||
| public void build() | public void build() | ||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| built = true; | |||||
| // @tf_utils.shape_type_conversion | // @tf_utils.shape_type_conversion | ||||
| // def build(self, input_shape) : | // def build(self, input_shape) : | ||||
| // if isinstance(input_shape, list) : | // if isinstance(input_shape, list) : | ||||
| @@ -168,9 +203,9 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| public GeneralizedTensorShape StateSize => throw new NotImplementedException(); | |||||
| public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); | public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); | ||||
| public bool IsTFRnnCell => throw new NotImplementedException(); | |||||
| public bool IsTFRnnCell => true; | |||||
| public bool SupportOptionalArgs => throw new NotImplementedException(); | public bool SupportOptionalArgs => throw new NotImplementedException(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -2,6 +2,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Keras.Callbacks; | using Tensorflow.Keras.Callbacks; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.NumPy; | |||||
| using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
| @@ -18,7 +19,7 @@ namespace Tensorflow.Keras.UnitTest.Callbacks | |||||
| var layers = keras.layers; | var layers = keras.layers; | ||||
| var model = keras.Sequential(new List<ILayer> | var model = keras.Sequential(new List<ILayer> | ||||
| { | { | ||||
| layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)), | |||||
| layers.Rescaling(1.0f / 255, input_shape: (28, 28, 1)), | |||||
| layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu), | layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu), | ||||
| layers.MaxPooling2D(), | layers.MaxPooling2D(), | ||||
| layers.Flatten(), | layers.Flatten(), | ||||
| @@ -36,8 +37,20 @@ namespace Tensorflow.Keras.UnitTest.Callbacks | |||||
| var num_epochs = 3; | var num_epochs = 3; | ||||
| var batch_size = 8; | var batch_size = 8; | ||||
| var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data(); | |||||
| x_train = x_train / 255.0f; | |||||
| var data_loader = new MnistModelLoader(); | |||||
| var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||||
| { | |||||
| TrainDir = "mnist", | |||||
| OneHot = false, | |||||
| ValidationSize = 59900, | |||||
| }).Result; | |||||
| NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1)); | |||||
| NDArray x2 = x1; | |||||
| var x = new NDArray[] { x1, x2 }; | |||||
| // define a CallbackParams first, the parameters you pass al least contain Model and Epochs. | // define a CallbackParams first, the parameters you pass al least contain Model and Epochs. | ||||
| CallbackParams callback_parameters = new CallbackParams | CallbackParams callback_parameters = new CallbackParams | ||||
| { | { | ||||
| @@ -47,10 +60,8 @@ namespace Tensorflow.Keras.UnitTest.Callbacks | |||||
| // define your earlystop | // define your earlystop | ||||
| ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy"); | ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy"); | ||||
| // define a callbcaklist, then add the earlystopping to it. | // define a callbcaklist, then add the earlystopping to it. | ||||
| var callbacks = new List<ICallback>(); | |||||
| callbacks.add(earlystop); | |||||
| model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks); | |||||
| var callbacks = new List<ICallback>{ earlystop}; | |||||
| model.fit(x, dataset.Train.Labels, batch_size, num_epochs, callbacks: callbacks); | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,25 +4,111 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
| using Tensorflow.Common.Types; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Layers.Rnn; | |||||
| using Tensorflow.Keras.Saving; | |||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using static Tensorflow.KerasApi; | |||||
| namespace Tensorflow.Keras.UnitTest.Layers | namespace Tensorflow.Keras.UnitTest.Layers | ||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class Rnn | public class Rnn | ||||
| { | { | ||||
| [TestMethod] | |||||
| public void SimpleRNNCell() | |||||
| { | |||||
| //var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f); | |||||
| //var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; | |||||
| //var x = tf.random.normal((4, 100)); | |||||
| //var (y, h1) = cell.Apply(inputs: x, states: h0); | |||||
| //var h2 = h1; | |||||
| //Assert.AreEqual((4, 64), y.shape); | |||||
| //Assert.AreEqual((4, 64), h2[0].shape); | |||||
| //var model = keras.Sequential(new List<ILayer> | |||||
| //{ | |||||
| // keras.layers.InputLayer(input_shape: (4,100)), | |||||
| // keras.layers.SimpleRNNCell(64) | |||||
| //}); | |||||
| //model.summary(); | |||||
| var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f); | |||||
| var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; | |||||
| var x = tf.random.normal((4, 100)); | |||||
| var (y, h1) = cell.Apply(inputs: x, states: h0); | |||||
| var h2 = h1; | |||||
| Assert.AreEqual((4, 64), y.shape); | |||||
| Assert.AreEqual((4, 64), h2[0].shape); | |||||
| } | |||||
| [TestMethod] | |||||
| public void StackedRNNCell() | |||||
| { | |||||
| var inputs = tf.ones((32, 10)); | |||||
| var states = new Tensors { tf.zeros((32, 4)), tf.zeros((32, 5)) }; | |||||
| var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) }; | |||||
| var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells); | |||||
| var (output, state) = stackedRNNCell.Apply(inputs, states); | |||||
| Console.WriteLine(output); | |||||
| Console.WriteLine(state.shape); | |||||
| Assert.AreEqual((32, 5), output.shape); | |||||
| Assert.AreEqual((32, 4), state[0].shape); | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void SimpleRNN() | public void SimpleRNN() | ||||
| { | { | ||||
| var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32); | |||||
| /*var simple_rnn = keras.layers.SimpleRNN(4); | |||||
| var output = simple_rnn.Apply(inputs); | |||||
| Assert.AreEqual((32, 4), output.shape);*/ | |||||
| var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true); | |||||
| var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs); | |||||
| Console.WriteLine(whole_sequence_output); | |||||
| Console.WriteLine(final_state); | |||||
| //var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32); | |||||
| ///*var simple_rnn = keras.layers.SimpleRNN(4); | |||||
| //var output = simple_rnn.Apply(inputs); | |||||
| //Assert.AreEqual((32, 4), output.shape);*/ | |||||
| //var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true); | |||||
| //var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs); | |||||
| //Assert.AreEqual((6, 10, 4), whole_sequence_output.shape); | |||||
| //Assert.AreEqual((6, 4), final_state.shape); | |||||
| var inputs = keras.Input(shape: (10, 8)); | |||||
| var x = keras.layers.SimpleRNN(4).Apply(inputs); | |||||
| var output = keras.layers.Dense(10).Apply(x); | |||||
| var model = keras.Model(inputs, output); | |||||
| model.summary(); | |||||
| } | |||||
| [TestMethod] | |||||
| public void RNNForSimpleRNNCell() | |||||
| { | |||||
| var inputs = tf.random.normal((32, 10, 8)); | |||||
| var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f); | |||||
| var rnn = tf.keras.layers.RNN(cell: cell); | |||||
| var output = rnn.Apply(inputs); | |||||
| Assert.AreEqual((32, 10), output.shape); | |||||
| } | } | ||||
| [TestMethod] | |||||
| public void RNNForStackedRNNCell() | |||||
| { | |||||
| var inputs = tf.random.normal((32, 10, 8)); | |||||
| var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) }; | |||||
| var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells); | |||||
| var rnn = tf.keras.layers.RNN(cell: stackedRNNCell); | |||||
| var output = rnn.Apply(inputs); | |||||
| Assert.AreEqual((32, 5), output.shape); | |||||
| } | |||||
| [TestMethod] | |||||
| public void WlzTest() | |||||
| { | |||||
| long[] b = { 1, 2, 3 }; | |||||
| Shape a = new Shape(Unknown).concatenate(b); | |||||
| Console.WriteLine(a); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||