From e8773462a47db898b501d028147202a715f9195a Mon Sep 17 00:00:00 2001 From: MPnoy Date: Mon, 1 Feb 2021 03:58:59 +0300 Subject: [PATCH] Blank SimpleRNN and test for it --- .../Keras/ArgsDefinition/RNNArgs.cs | 17 ++- .../Keras/ArgsDefinition/SimpleRNNArgs.cs | 30 +++++ .../ArgsDefinition/StackedRNNCellsArgs.cs | 9 ++ .../Operations/NnOps/RNNCell.cs | 2 +- src/TensorFlowNET.Keras/Layers/LayersApi.cs | 19 +++ src/TensorFlowNET.Keras/Layers/RNN.cs | 86 +++++++++++- src/TensorFlowNET.Keras/Layers/SimpleRNN.cs | 14 ++ .../Layers/StackedRNNCells.cs | 125 ++++++++++++++++++ .../Layers/LayersTest.cs | 13 +- 9 files changed, 307 insertions(+), 8 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs create mode 100644 src/TensorFlowNET.Keras/Layers/SimpleRNN.cs create mode 100644 src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs index 623cc68e..3ebcf617 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs @@ -1,6 +1,21 @@ -namespace Tensorflow.Keras.ArgsDefinition +using System.Collections.Generic; + +namespace Tensorflow.Keras.ArgsDefinition { public class RNNArgs : LayerArgs { + public interface IRnnArgCell : ILayer + { + object state_size { get; } + } + + public IRnnArgCell Cell { get; set; } = null; + public bool ReturnSequences { get; set; } = false; + public bool ReturnState { get; set; } = false; + public bool GoBackwards { get; set; } = false; + public bool Stateful { get; set; } = false; + public bool Unroll { get; set; } = false; + public bool TimeMajor { get; set; } = false; + public Dictionary Kwargs { get; set; } = null; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs new file mode 100644 index 00000000..65815587 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs @@ -0,0 +1,30 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + public class SimpleRNNArgs : RNNArgs + { + public int Units { get; set; } + public Activation Activation { get; set; } + + // units, + // activation='tanh', + // use_bias=True, + // kernel_initializer='glorot_uniform', + // recurrent_initializer='orthogonal', + // bias_initializer='zeros', + // kernel_regularizer=None, + // recurrent_regularizer=None, + // bias_regularizer=None, + // activity_regularizer=None, + // kernel_constraint=None, + // recurrent_constraint=None, + // bias_constraint=None, + // dropout=0., + // recurrent_dropout=0., + // return_sequences=False, + // return_state=False, + // go_backwards=False, + // stateful=False, + // unroll=False, + // **kwargs): + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs new file mode 100644 index 00000000..1c52e47b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs @@ -0,0 +1,9 @@ +using System.Collections.Generic; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class StackedRNNCellsArgs : LayerArgs + { + public IList Cells { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index aaea5cd2..0dd40096 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -46,7 +46,7 @@ namespace Tensorflow /// matching structure of Tensors having shape `[batch_size].concatenate(s)` /// for each `s` in `self.batch_size`. /// - public abstract class RnnCell : ILayer + public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell { /// /// Attribute that indicates whether the cell is a TF RNN cell, due the slight diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index a2a29770..3f8fae3d 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -1,4 +1,5 @@ using NumSharp; +using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; @@ -327,6 +328,24 @@ namespace Tensorflow.Keras.Layers Alpha = alpha }); + public Layer SimpleRNN(int units) => SimpleRNN(units, "tanh"); + + public Layer SimpleRNN(int units, + Activation activation = null) + => new SimpleRNN(new SimpleRNNArgs + { + Units = units, + Activation = activation + }); + + public Layer SimpleRNN(int units, + string activation = "tanh") + => new SimpleRNN(new SimpleRNNArgs + { + Units = units, + Activation = GetActivationByName(activation) + }); + public Layer LSTM(int units, Activation activation = null, Activation recurrent_activation = null, diff --git a/src/TensorFlowNET.Keras/Layers/RNN.cs b/src/TensorFlowNET.Keras/Layers/RNN.cs index 3d03abb1..0c77d57f 100644 --- a/src/TensorFlowNET.Keras/Layers/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/RNN.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; @@ -6,12 +7,93 @@ namespace Tensorflow.Keras.Layers { public class RNN : Layer { - public RNN(RNNArgs args) - : base(args) + private RNNArgs args; + + public RNN(RNNArgs args) : base(PreConstruct(args)) { + this.args = args; + SupportsMasking = true; + + // The input shape is unknown yet, it could have nested tensor inputs, and + // the input spec will be the list of specs for nested inputs, the structure + // of the input_spec will be the same as the input. + //self.input_spec = None + //self.state_spec = None + //self._states = None + //self.constants_spec = None + //self._num_constants = 0 + + //if stateful: + // if ds_context.has_strategy(): + // raise ValueError('RNNs with stateful=True not yet supported with ' + // 'tf.distribute.Strategy.') } + private static RNNArgs PreConstruct(RNNArgs args) + { + if (args.Kwargs == null) + { + args.Kwargs = new Dictionary(); + } + + // If true, the output for masked timestep will be zeros, whereas in the + // false case, output from previous timestep is returned for masked timestep. + var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false); + + object input_shape; + var propIS = args.Kwargs.Get("input_shape", null); + var propID = args.Kwargs.Get("input_dim", null); + var propIL = args.Kwargs.Get("input_length", null); + + if (propIS == null && (propID != null || propIL != null)) + { + input_shape = ( + propIL ?? new NoneValue(), // maybe null is needed here + propID ?? new NoneValue()); // and here + args.Kwargs["input_shape"] = input_shape; + } + + return args; + } + + public RNN New(LayerRnnCell cell, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool unroll = false, + bool time_major = false) + => new RNN(new RNNArgs + { + Cell = cell, + ReturnSequences = return_sequences, + ReturnState = return_state, + GoBackwards = go_backwards, + Stateful = stateful, + Unroll = unroll, + TimeMajor = time_major + }); + + public RNN New(IList cell, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool unroll = false, + bool time_major = false) + => new RNN(new RNNArgs + { + Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }), + ReturnSequences = return_sequences, + ReturnState = return_state, + GoBackwards = go_backwards, + Stateful = stateful, + Unroll = unroll, + TimeMajor = time_major + }); + + protected Tensor get_initial_state(Tensor inputs) { return _generate_zero_filled_state_for_cell(null, null); diff --git a/src/TensorFlowNET.Keras/Layers/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/SimpleRNN.cs new file mode 100644 index 00000000..c1fc4afd --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/SimpleRNN.cs @@ -0,0 +1,14 @@ +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Layers +{ + public class SimpleRNN : RNN + { + + public SimpleRNN(RNNArgs args) : base(args) + { + + } + + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs new file mode 100644 index 00000000..c0a2371f --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs @@ -0,0 +1,125 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell + { + public IList Cells { get; set; } + + public StackedRNNCells(StackedRNNCellsArgs args) : base(args) + { + Cells = args.Cells; + //Cells.reverse_state_order = kwargs.pop('reverse_state_order', False); + // self.reverse_state_order = kwargs.pop('reverse_state_order', False) + // if self.reverse_state_order: + // logging.warning('reverse_state_order=True in StackedRNNCells will soon ' + // 'be deprecated. Please update the code to work with the ' + // 'natural order of states if you rely on the RNN states, ' + // 'eg RNN(return_state=True).') + // super(StackedRNNCells, self).__init__(**kwargs) + throw new NotImplementedException(""); + } + + public object state_size + { + get => throw new NotImplementedException(); + } + + //@property + //def state_size(self) : + // return tuple(c.state_size for c in + // (self.cells[::- 1] if self.reverse_state_order else self.cells)) + + // @property + // def output_size(self) : + // if getattr(self.cells[-1], 'output_size', None) is not None: + // return self.cells[-1].output_size + // elif _is_multiple_state(self.cells[-1].state_size) : + // return self.cells[-1].state_size[0] + // else: + // return self.cells[-1].state_size + + // def get_initial_state(self, inputs= None, batch_size= None, dtype= None) : + // initial_states = [] + // for cell in self.cells[::- 1] if self.reverse_state_order else self.cells: + // get_initial_state_fn = getattr(cell, 'get_initial_state', None) + // if get_initial_state_fn: + // initial_states.append(get_initial_state_fn( + // inputs=inputs, batch_size=batch_size, dtype=dtype)) + // else: + // initial_states.append(_generate_zero_filled_state_for_cell( + // cell, inputs, batch_size, dtype)) + + // return tuple(initial_states) + + // def call(self, inputs, states, constants= None, training= None, ** kwargs): + // # Recover per-cell states. + // state_size = (self.state_size[::- 1] + // if self.reverse_state_order else self.state_size) + // nested_states = nest.pack_sequence_as(state_size, nest.flatten(states)) + + // # Call the cells in order and store the returned states. + // new_nested_states = [] + // for cell, states in zip(self.cells, nested_states) : + // states = states if nest.is_nested(states) else [states] + //# TF cell does not wrap the state into list when there is only one state. + // is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None + // states = states[0] if len(states) == 1 and is_tf_rnn_cell else states + // if generic_utils.has_arg(cell.call, 'training'): + // kwargs['training'] = training + // else: + // kwargs.pop('training', None) + // # Use the __call__ function for callable objects, eg layers, so that it + // # will have the proper name scopes for the ops, etc. + // cell_call_fn = cell.__call__ if callable(cell) else cell.call + // if generic_utils.has_arg(cell.call, 'constants'): + // inputs, states = cell_call_fn(inputs, states, + // constants= constants, ** kwargs) + // else: + // inputs, states = cell_call_fn(inputs, states, ** kwargs) + // new_nested_states.append(states) + + // return inputs, nest.pack_sequence_as(state_size, + // nest.flatten(new_nested_states)) + + // @tf_utils.shape_type_conversion + // def build(self, input_shape) : + // if isinstance(input_shape, list) : + // input_shape = input_shape[0] + // for cell in self.cells: + // if isinstance(cell, Layer) and not cell.built: + // with K.name_scope(cell.name): + // cell.build(input_shape) + // cell.built = True + // if getattr(cell, 'output_size', None) is not None: + // output_dim = cell.output_size + // elif _is_multiple_state(cell.state_size) : + // output_dim = cell.state_size[0] + // else: + // output_dim = cell.state_size + // input_shape = tuple([input_shape[0]] + + // tensor_shape.TensorShape(output_dim).as_list()) + // self.built = True + + // def get_config(self) : + // cells = [] + // for cell in self.cells: + // cells.append(generic_utils.serialize_keras_object(cell)) + // config = {'cells': cells + //} + //base_config = super(StackedRNNCells, self).get_config() + // return dict(list(base_config.items()) + list(config.items())) + + // @classmethod + // def from_config(cls, config, custom_objects = None): + // from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + // cells = [] + // for cell_config in config.pop('cells'): + // cells.append( + // deserialize_layer(cell_config, custom_objects = custom_objects)) + // return cls(cells, **config) + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index 62d9fa5c..63e959f5 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -36,7 +36,7 @@ namespace TensorFlowNET.Keras.UnitTest var model = keras.Model(inputs, outputs, name: "mnist_model"); model.summary(); } - + /// /// Custom layer test, used in Dueling DQN /// @@ -45,10 +45,10 @@ namespace TensorFlowNET.Keras.UnitTest { var layers = keras.layers; var inputs = layers.Input(shape: 24); - var x = layers.Dense(128, activation:"relu").Apply(inputs); + var x = layers.Dense(128, activation: "relu").Apply(inputs); var value = layers.Dense(24).Apply(x); var adv = layers.Dense(1).Apply(x); - + var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true); adv = layers.Subtract().Apply((adv, mean)); var outputs = layers.Add().Apply((value, adv)); @@ -105,9 +105,14 @@ namespace TensorFlowNET.Keras.UnitTest } [TestMethod] + [Ignore] public void SimpleRNN() { - + var inputs = np.random.rand(32, 10, 8).astype(np.float32); + var simple_rnn = keras.layers.SimpleRNN(4); + var output = simple_rnn.Apply(inputs); + Assert.AreEqual((32, 4), output.shape); } + } }