diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs index 1c52e47b..9b910e17 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs @@ -5,5 +5,6 @@ namespace Tensorflow.Keras.ArgsDefinition public class StackedRNNCellsArgs : LayerArgs { public IList Cells { get; set; } + public Dictionary Kwargs { get; set; } = null; } } diff --git a/src/TensorFlowNET.Keras/Layers/RNN.cs b/src/TensorFlowNET.Keras/Layers/RNN.cs index 0c77d57f..411869e4 100644 --- a/src/TensorFlowNET.Keras/Layers/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/RNN.cs @@ -2,12 +2,18 @@ using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +// from tensorflow.python.distribute import distribution_strategy_context as ds_context; namespace Tensorflow.Keras.Layers { public class RNN : Layer { private RNNArgs args; + private object input_spec = null; // or NoneValue?? + private object state_spec = null; + private object _states = null; + private object constants_spec = null; + private int _num_constants = 0; public RNN(RNNArgs args) : base(PreConstruct(args)) { @@ -18,16 +24,13 @@ namespace Tensorflow.Keras.Layers // the input spec will be the list of specs for nested inputs, the structure // of the input_spec will be the same as the input. - //self.input_spec = None - //self.state_spec = None - //self._states = None - //self.constants_spec = None - //self._num_constants = 0 - - //if stateful: - // if ds_context.has_strategy(): - // raise ValueError('RNNs with stateful=True not yet supported with ' - // 'tf.distribute.Strategy.') + //if(stateful) + //{ + // if (ds_context.has_strategy()) // ds_context???? 
+ // { + // throw new Exception("RNNs with stateful=True not yet supported with tf.distribute.Strategy"); + // } + //} } private static RNNArgs PreConstruct(RNNArgs args) @@ -41,16 +44,16 @@ namespace Tensorflow.Keras.Layers // false case, output from previous timestep is returned for masked timestep. var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false); - object input_shape; - var propIS = args.Kwargs.Get("input_shape", null); - var propID = args.Kwargs.Get("input_dim", null); - var propIL = args.Kwargs.Get("input_length", null); + TensorShape input_shape; + var propIS = (TensorShape)args.Kwargs.Get("input_shape", null); + var propID = (int?)args.Kwargs.Get("input_dim", null); + var propIL = (int?)args.Kwargs.Get("input_length", null); if (propIS == null && (propID != null || propIL != null)) { - input_shape = ( - propIL ?? new NoneValue(), // maybe null is needed here - propID ?? new NoneValue()); // and here + input_shape = new TensorShape( + propIL ?? -1, + propID ?? -1); args.Kwargs["input_shape"] = input_shape; } @@ -103,5 +106,14 @@ namespace Tensorflow.Keras.Layers { throw new NotImplementedException(""); } + + // Check whether the state_size contains multiple states. 
+ public static bool _is_multiple_state(object state_size) + { + var myIndexerProperty = state_size.GetType().GetProperty("Item"); + return myIndexerProperty != null + && myIndexerProperty.GetIndexParameters().Length == 1 + && !(state_size.GetType() == typeof(TensorShape)); + } } } diff --git a/src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs index c0a2371f..dad7e0af 100644 --- a/src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs +++ b/src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.ComponentModel; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; @@ -8,118 +9,155 @@ namespace Tensorflow.Keras.Layers public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell { public IList Cells { get; set; } + public bool reverse_state_order; public StackedRNNCells(StackedRNNCellsArgs args) : base(args) { + if (args.Kwargs == null) + { + args.Kwargs = new Dictionary(); + } + Cells = args.Cells; - //Cells.reverse_state_order = kwargs.pop('reverse_state_order', False); - // self.reverse_state_order = kwargs.pop('reverse_state_order', False) - // if self.reverse_state_order: - // logging.warning('reverse_state_order=True in StackedRNNCells will soon ' - // 'be deprecated. Please update the code to work with the ' - // 'natural order of states if you rely on the RNN states, ' - // 'eg RNN(return_state=True).') - // super(StackedRNNCells, self).__init__(**kwargs) - throw new NotImplementedException(""); + reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false); + + if (reverse_state_order) + { + throw new WarningException("reverse_state_order=True in StackedRNNCells will soon " + + "be deprecated. 
Please update the code to work with the " + + "natural order of states if you rely on the RNN states, " + + "eg RNN(return_state=True)."); + } } public object state_size { get => throw new NotImplementedException(); + //@property + //def state_size(self) : + // return tuple(c.state_size for c in + // (self.cells[::- 1] if self.reverse_state_order else self.cells)) + } + + public object output_size + { + get + { + var lastCell = Cells[Cells.Count - 1]; + + if (lastCell.output_size != -1) + { + return lastCell.output_size; + } + else if (RNN._is_multiple_state(lastCell.state_size)) + { + return ((dynamic)lastCell.state_size)[0]; + } + else + { + return lastCell.state_size; + } + } + } + + public object get_initial_state() + { + throw new NotImplementedException(); + // def get_initial_state(self, inputs= None, batch_size= None, dtype= None) : + // initial_states = [] + // for cell in self.cells[::- 1] if self.reverse_state_order else self.cells: + // get_initial_state_fn = getattr(cell, 'get_initial_state', None) + // if get_initial_state_fn: + // initial_states.append(get_initial_state_fn( + // inputs=inputs, batch_size=batch_size, dtype=dtype)) + // else: + // initial_states.append(_generate_zero_filled_state_for_cell( + // cell, inputs, batch_size, dtype)) + + // return tuple(initial_states) + } + + public object call() + { + throw new NotImplementedException(); + // def call(self, inputs, states, constants= None, training= None, ** kwargs): + // # Recover per-cell states. + // state_size = (self.state_size[::- 1] + // if self.reverse_state_order else self.state_size) + // nested_states = nest.pack_sequence_as(state_size, nest.flatten(states)) + + // # Call the cells in order and store the returned states. + // new_nested_states = [] + // for cell, states in zip(self.cells, nested_states) : + // states = states if nest.is_nested(states) else [states] + //# TF cell does not wrap the state into list when there is only one state. 
+ // is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None + // states = states[0] if len(states) == 1 and is_tf_rnn_cell else states + // if generic_utils.has_arg(cell.call, 'training'): + // kwargs['training'] = training + // else: + // kwargs.pop('training', None) + // # Use the __call__ function for callable objects, eg layers, so that it + // # will have the proper name scopes for the ops, etc. + // cell_call_fn = cell.__call__ if callable(cell) else cell.call + // if generic_utils.has_arg(cell.call, 'constants'): + // inputs, states = cell_call_fn(inputs, states, + // constants= constants, ** kwargs) + // else: + // inputs, states = cell_call_fn(inputs, states, ** kwargs) + // new_nested_states.append(states) + + // return inputs, nest.pack_sequence_as(state_size, + // nest.flatten(new_nested_states)) + } + + public void build() + { + throw new NotImplementedException(); + // @tf_utils.shape_type_conversion + // def build(self, input_shape) : + // if isinstance(input_shape, list) : + // input_shape = input_shape[0] + // for cell in self.cells: + // if isinstance(cell, Layer) and not cell.built: + // with K.name_scope(cell.name): + // cell.build(input_shape) + // cell.built = True + // if getattr(cell, 'output_size', None) is not None: + // output_dim = cell.output_size + // elif _is_multiple_state(cell.state_size) : + // output_dim = cell.state_size[0] + // else: + // output_dim = cell.state_size + // input_shape = tuple([input_shape[0]] + + // tensor_shape.TensorShape(output_dim).as_list()) + // self.built = True } - //@property - //def state_size(self) : - // return tuple(c.state_size for c in - // (self.cells[::- 1] if self.reverse_state_order else self.cells)) - - // @property - // def output_size(self) : - // if getattr(self.cells[-1], 'output_size', None) is not None: - // return self.cells[-1].output_size - // elif _is_multiple_state(self.cells[-1].state_size) : - // return self.cells[-1].state_size[0] - // else: - // return 
self.cells[-1].state_size - - // def get_initial_state(self, inputs= None, batch_size= None, dtype= None) : - // initial_states = [] - // for cell in self.cells[::- 1] if self.reverse_state_order else self.cells: - // get_initial_state_fn = getattr(cell, 'get_initial_state', None) - // if get_initial_state_fn: - // initial_states.append(get_initial_state_fn( - // inputs=inputs, batch_size=batch_size, dtype=dtype)) - // else: - // initial_states.append(_generate_zero_filled_state_for_cell( - // cell, inputs, batch_size, dtype)) - - // return tuple(initial_states) - - // def call(self, inputs, states, constants= None, training= None, ** kwargs): - // # Recover per-cell states. - // state_size = (self.state_size[::- 1] - // if self.reverse_state_order else self.state_size) - // nested_states = nest.pack_sequence_as(state_size, nest.flatten(states)) - - // # Call the cells in order and store the returned states. - // new_nested_states = [] - // for cell, states in zip(self.cells, nested_states) : - // states = states if nest.is_nested(states) else [states] - //# TF cell does not wrap the state into list when there is only one state. - // is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None - // states = states[0] if len(states) == 1 and is_tf_rnn_cell else states - // if generic_utils.has_arg(cell.call, 'training'): - // kwargs['training'] = training - // else: - // kwargs.pop('training', None) - // # Use the __call__ function for callable objects, eg layers, so that it - // # will have the proper name scopes for the ops, etc. 
- // cell_call_fn = cell.__call__ if callable(cell) else cell.call - // if generic_utils.has_arg(cell.call, 'constants'): - // inputs, states = cell_call_fn(inputs, states, - // constants= constants, ** kwargs) - // else: - // inputs, states = cell_call_fn(inputs, states, ** kwargs) - // new_nested_states.append(states) - - // return inputs, nest.pack_sequence_as(state_size, - // nest.flatten(new_nested_states)) - - // @tf_utils.shape_type_conversion - // def build(self, input_shape) : - // if isinstance(input_shape, list) : - // input_shape = input_shape[0] - // for cell in self.cells: - // if isinstance(cell, Layer) and not cell.built: - // with K.name_scope(cell.name): - // cell.build(input_shape) - // cell.built = True - // if getattr(cell, 'output_size', None) is not None: - // output_dim = cell.output_size - // elif _is_multiple_state(cell.state_size) : - // output_dim = cell.state_size[0] - // else: - // output_dim = cell.state_size - // input_shape = tuple([input_shape[0]] + - // tensor_shape.TensorShape(output_dim).as_list()) - // self.built = True - - // def get_config(self) : - // cells = [] - // for cell in self.cells: - // cells.append(generic_utils.serialize_keras_object(cell)) - // config = {'cells': cells - //} - //base_config = super(StackedRNNCells, self).get_config() - // return dict(list(base_config.items()) + list(config.items())) - - // @classmethod - // def from_config(cls, config, custom_objects = None): - // from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top - // cells = [] - // for cell_config in config.pop('cells'): - // cells.append( - // deserialize_layer(cell_config, custom_objects = custom_objects)) - // return cls(cells, **config) + public override LayerArgs get_config() + { + throw new NotImplementedException(); + //def get_config(self): + // cells = [] + // for cell in self.cells: + // cells.append(generic_utils.serialize_keras_object(cell)) + // config = {'cells': 
cells} + // base_config = super(StackedRNNCells, self).get_config() + // return dict(list(base_config.items()) + list(config.items())) + } + + + public void from_config() + { + throw new NotImplementedException(); + // @classmethod + // def from_config(cls, config, custom_objects = None): + // from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + // cells = [] + // for cell_config in config.pop('cells'): + // cells.append( + // deserialize_layer(cell_config, custom_objects = custom_objects)) + // return cls(cells, **config) + } } }