| @@ -1,6 +1,21 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class RNNArgs : LayerArgs | public class RNNArgs : LayerArgs | ||||
| { | { | ||||
| public interface IRnnArgCell : ILayer | |||||
| { | |||||
| object state_size { get; } | |||||
| } | |||||
| public IRnnArgCell Cell { get; set; } = null; | |||||
| public bool ReturnSequences { get; set; } = false; | |||||
| public bool ReturnState { get; set; } = false; | |||||
| public bool GoBackwards { get; set; } = false; | |||||
| public bool Stateful { get; set; } = false; | |||||
| public bool Unroll { get; set; } = false; | |||||
| public bool TimeMajor { get; set; } = false; | |||||
| public Dictionary<string, object> Kwargs { get; set; } = null; | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,30 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class SimpleRNNArgs : RNNArgs | |||||
| { | |||||
| public int Units { get; set; } | |||||
| public Activation Activation { get; set; } | |||||
| // units, | |||||
| // activation='tanh', | |||||
| // use_bias=True, | |||||
| // kernel_initializer='glorot_uniform', | |||||
| // recurrent_initializer='orthogonal', | |||||
| // bias_initializer='zeros', | |||||
| // kernel_regularizer=None, | |||||
| // recurrent_regularizer=None, | |||||
| // bias_regularizer=None, | |||||
| // activity_regularizer=None, | |||||
| // kernel_constraint=None, | |||||
| // recurrent_constraint=None, | |||||
| // bias_constraint=None, | |||||
| // dropout=0., | |||||
| // recurrent_dropout=0., | |||||
| // return_sequences=False, | |||||
| // return_state=False, | |||||
| // go_backwards=False, | |||||
| // stateful=False, | |||||
| // unroll=False, | |||||
| // **kwargs): | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,9 @@ | |||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class StackedRNNCellsArgs : LayerArgs | |||||
| { | |||||
| public IList<RnnCell> Cells { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -46,7 +46,7 @@ namespace Tensorflow | |||||
| /// matching structure of Tensors having shape `[batch_size].concatenate(s)` | /// matching structure of Tensors having shape `[batch_size].concatenate(s)` | ||||
| /// for each `s` in `self.batch_size`. | /// for each `s` in `self.batch_size`. | ||||
| /// </summary> | /// </summary> | ||||
| public abstract class RnnCell : ILayer | |||||
| public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Attribute that indicates whether the cell is a TF RNN cell, due the slight | /// Attribute that indicates whether the cell is a TF RNN cell, due the slight | ||||
| @@ -1,4 +1,5 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System.Collections.Generic; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -327,6 +328,24 @@ namespace Tensorflow.Keras.Layers | |||||
| Alpha = alpha | Alpha = alpha | ||||
| }); | }); | ||||
| public Layer SimpleRNN(int units) => SimpleRNN(units, "tanh"); | |||||
| public Layer SimpleRNN(int units, | |||||
| Activation activation = null) | |||||
| => new SimpleRNN(new SimpleRNNArgs | |||||
| { | |||||
| Units = units, | |||||
| Activation = activation | |||||
| }); | |||||
| public Layer SimpleRNN(int units, | |||||
| string activation = "tanh") | |||||
| => new SimpleRNN(new SimpleRNNArgs | |||||
| { | |||||
| Units = units, | |||||
| Activation = GetActivationByName(activation) | |||||
| }); | |||||
| public Layer LSTM(int units, | public Layer LSTM(int units, | ||||
| Activation activation = null, | Activation activation = null, | ||||
| Activation recurrent_activation = null, | Activation recurrent_activation = null, | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| @@ -6,12 +7,93 @@ namespace Tensorflow.Keras.Layers | |||||
| { | { | ||||
| public class RNN : Layer | public class RNN : Layer | ||||
| { | { | ||||
| public RNN(RNNArgs args) | |||||
| : base(args) | |||||
| private RNNArgs args; | |||||
| public RNN(RNNArgs args) : base(PreConstruct(args)) | |||||
| { | { | ||||
| this.args = args; | |||||
| SupportsMasking = true; | |||||
| // The input shape is unknown yet, it could have nested tensor inputs, and | |||||
| // the input spec will be the list of specs for nested inputs, the structure | |||||
| // of the input_spec will be the same as the input. | |||||
| //self.input_spec = None | |||||
| //self.state_spec = None | |||||
| //self._states = None | |||||
| //self.constants_spec = None | |||||
| //self._num_constants = 0 | |||||
| //if stateful: | |||||
| // if ds_context.has_strategy(): | |||||
| // raise ValueError('RNNs with stateful=True not yet supported with ' | |||||
| // 'tf.distribute.Strategy.') | |||||
| } | } | ||||
| private static RNNArgs PreConstruct(RNNArgs args) | |||||
| { | |||||
| if (args.Kwargs == null) | |||||
| { | |||||
| args.Kwargs = new Dictionary<string, object>(); | |||||
| } | |||||
| // If true, the output for masked timestep will be zeros, whereas in the | |||||
| // false case, output from previous timestep is returned for masked timestep. | |||||
| var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false); | |||||
| object input_shape; | |||||
| var propIS = args.Kwargs.Get("input_shape", null); | |||||
| var propID = args.Kwargs.Get("input_dim", null); | |||||
| var propIL = args.Kwargs.Get("input_length", null); | |||||
| if (propIS == null && (propID != null || propIL != null)) | |||||
| { | |||||
| input_shape = ( | |||||
| propIL ?? new NoneValue(), // maybe null is needed here | |||||
| propID ?? new NoneValue()); // and here | |||||
| args.Kwargs["input_shape"] = input_shape; | |||||
| } | |||||
| return args; | |||||
| } | |||||
| public RNN New(LayerRnnCell cell, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool unroll = false, | |||||
| bool time_major = false) | |||||
| => new RNN(new RNNArgs | |||||
| { | |||||
| Cell = cell, | |||||
| ReturnSequences = return_sequences, | |||||
| ReturnState = return_state, | |||||
| GoBackwards = go_backwards, | |||||
| Stateful = stateful, | |||||
| Unroll = unroll, | |||||
| TimeMajor = time_major | |||||
| }); | |||||
| public RNN New(IList<RnnCell> cell, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool unroll = false, | |||||
| bool time_major = false) | |||||
| => new RNN(new RNNArgs | |||||
| { | |||||
| Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }), | |||||
| ReturnSequences = return_sequences, | |||||
| ReturnState = return_state, | |||||
| GoBackwards = go_backwards, | |||||
| Stateful = stateful, | |||||
| Unroll = unroll, | |||||
| TimeMajor = time_major | |||||
| }); | |||||
| protected Tensor get_initial_state(Tensor inputs) | protected Tensor get_initial_state(Tensor inputs) | ||||
| { | { | ||||
| return _generate_zero_filled_state_for_cell(null, null); | return _generate_zero_filled_state_for_cell(null, null); | ||||
| @@ -0,0 +1,14 @@ | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| public class SimpleRNN : RNN | |||||
| { | |||||
| public SimpleRNN(RNNArgs args) : base(args) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,125 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell | |||||
| { | |||||
| public IList<RnnCell> Cells { get; set; } | |||||
| public StackedRNNCells(StackedRNNCellsArgs args) : base(args) | |||||
| { | |||||
| Cells = args.Cells; | |||||
| //Cells.reverse_state_order = kwargs.pop('reverse_state_order', False); | |||||
| // self.reverse_state_order = kwargs.pop('reverse_state_order', False) | |||||
| // if self.reverse_state_order: | |||||
| // logging.warning('reverse_state_order=True in StackedRNNCells will soon ' | |||||
| // 'be deprecated. Please update the code to work with the ' | |||||
| // 'natural order of states if you rely on the RNN states, ' | |||||
| // 'eg RNN(return_state=True).') | |||||
| // super(StackedRNNCells, self).__init__(**kwargs) | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| public object state_size | |||||
| { | |||||
| get => throw new NotImplementedException(); | |||||
| } | |||||
| //@property | |||||
| //def state_size(self) : | |||||
| // return tuple(c.state_size for c in | |||||
| // (self.cells[::- 1] if self.reverse_state_order else self.cells)) | |||||
| // @property | |||||
| // def output_size(self) : | |||||
| // if getattr(self.cells[-1], 'output_size', None) is not None: | |||||
| // return self.cells[-1].output_size | |||||
| // elif _is_multiple_state(self.cells[-1].state_size) : | |||||
| // return self.cells[-1].state_size[0] | |||||
| // else: | |||||
| // return self.cells[-1].state_size | |||||
| // def get_initial_state(self, inputs= None, batch_size= None, dtype= None) : | |||||
| // initial_states = [] | |||||
| // for cell in self.cells[::- 1] if self.reverse_state_order else self.cells: | |||||
| // get_initial_state_fn = getattr(cell, 'get_initial_state', None) | |||||
| // if get_initial_state_fn: | |||||
| // initial_states.append(get_initial_state_fn( | |||||
| // inputs=inputs, batch_size=batch_size, dtype=dtype)) | |||||
| // else: | |||||
| // initial_states.append(_generate_zero_filled_state_for_cell( | |||||
| // cell, inputs, batch_size, dtype)) | |||||
| // return tuple(initial_states) | |||||
| // def call(self, inputs, states, constants= None, training= None, ** kwargs): | |||||
| // # Recover per-cell states. | |||||
| // state_size = (self.state_size[::- 1] | |||||
| // if self.reverse_state_order else self.state_size) | |||||
| // nested_states = nest.pack_sequence_as(state_size, nest.flatten(states)) | |||||
| // # Call the cells in order and store the returned states. | |||||
| // new_nested_states = [] | |||||
| // for cell, states in zip(self.cells, nested_states) : | |||||
| // states = states if nest.is_nested(states) else [states] | |||||
| //# TF cell does not wrap the state into list when there is only one state. | |||||
| // is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None | |||||
| // states = states[0] if len(states) == 1 and is_tf_rnn_cell else states | |||||
| // if generic_utils.has_arg(cell.call, 'training'): | |||||
| // kwargs['training'] = training | |||||
| // else: | |||||
| // kwargs.pop('training', None) | |||||
| // # Use the __call__ function for callable objects, eg layers, so that it | |||||
| // # will have the proper name scopes for the ops, etc. | |||||
| // cell_call_fn = cell.__call__ if callable(cell) else cell.call | |||||
| // if generic_utils.has_arg(cell.call, 'constants'): | |||||
| // inputs, states = cell_call_fn(inputs, states, | |||||
| // constants= constants, ** kwargs) | |||||
| // else: | |||||
| // inputs, states = cell_call_fn(inputs, states, ** kwargs) | |||||
| // new_nested_states.append(states) | |||||
| // return inputs, nest.pack_sequence_as(state_size, | |||||
| // nest.flatten(new_nested_states)) | |||||
| // @tf_utils.shape_type_conversion | |||||
| // def build(self, input_shape) : | |||||
| // if isinstance(input_shape, list) : | |||||
| // input_shape = input_shape[0] | |||||
| // for cell in self.cells: | |||||
| // if isinstance(cell, Layer) and not cell.built: | |||||
| // with K.name_scope(cell.name): | |||||
| // cell.build(input_shape) | |||||
| // cell.built = True | |||||
| // if getattr(cell, 'output_size', None) is not None: | |||||
| // output_dim = cell.output_size | |||||
| // elif _is_multiple_state(cell.state_size) : | |||||
| // output_dim = cell.state_size[0] | |||||
| // else: | |||||
| // output_dim = cell.state_size | |||||
| // input_shape = tuple([input_shape[0]] + | |||||
| // tensor_shape.TensorShape(output_dim).as_list()) | |||||
| // self.built = True | |||||
| // def get_config(self) : | |||||
| // cells = [] | |||||
| // for cell in self.cells: | |||||
| // cells.append(generic_utils.serialize_keras_object(cell)) | |||||
| // config = {'cells': cells | |||||
| //} | |||||
| //base_config = super(StackedRNNCells, self).get_config() | |||||
| // return dict(list(base_config.items()) + list(config.items())) | |||||
| // @classmethod | |||||
| // def from_config(cls, config, custom_objects = None): | |||||
| // from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top | |||||
| // cells = [] | |||||
| // for cell_config in config.pop('cells'): | |||||
| // cells.append( | |||||
| // deserialize_layer(cell_config, custom_objects = custom_objects)) | |||||
| // return cls(cells, **config) | |||||
| } | |||||
| } | |||||
| @@ -36,7 +36,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
| var model = keras.Model(inputs, outputs, name: "mnist_model"); | var model = keras.Model(inputs, outputs, name: "mnist_model"); | ||||
| model.summary(); | model.summary(); | ||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Custom layer test, used in Dueling DQN | /// Custom layer test, used in Dueling DQN | ||||
| /// </summary> | /// </summary> | ||||
| @@ -45,10 +45,10 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
| { | { | ||||
| var layers = keras.layers; | var layers = keras.layers; | ||||
| var inputs = layers.Input(shape: 24); | var inputs = layers.Input(shape: 24); | ||||
| var x = layers.Dense(128, activation:"relu").Apply(inputs); | |||||
| var x = layers.Dense(128, activation: "relu").Apply(inputs); | |||||
| var value = layers.Dense(24).Apply(x); | var value = layers.Dense(24).Apply(x); | ||||
| var adv = layers.Dense(1).Apply(x); | var adv = layers.Dense(1).Apply(x); | ||||
| var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true); | var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true); | ||||
| adv = layers.Subtract().Apply((adv, mean)); | adv = layers.Subtract().Apply((adv, mean)); | ||||
| var outputs = layers.Add().Apply((value, adv)); | var outputs = layers.Add().Apply((value, adv)); | ||||
| @@ -105,9 +105,14 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| [Ignore] | |||||
| public void SimpleRNN() | public void SimpleRNN() | ||||
| { | { | ||||
| var inputs = np.random.rand(32, 10, 8).astype(np.float32); | |||||
| var simple_rnn = keras.layers.SimpleRNN(4); | |||||
| var output = simple_rnn.Apply(inputs); | |||||
| Assert.AreEqual((32, 4), output.shape); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||