using System;
using System.ComponentModel;
using System.Linq;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers
{
    /// <summary>
    /// Wraps a sequence of RNN cells so they behave as a single cell.
    /// On each step the cells are applied in order, and each cell's output becomes the next
    /// cell's input. The state of the stack is made up of the per-cell states.
    /// </summary>
    public class StackedRNNCells : Layer, IRnnCell
    {
        /// <summary>The wrapped cells, in application order.</summary>
        public IList Cells { get; set; }

        /// <summary>
        /// Whether per-cell states are exposed in reverse cell order.
        /// The constructor rejects <c>true</c>. Because the field is public, it could still be
        /// set afterwards.
        /// </summary>
        public bool _reverse_state_order;

        /// <summary>Creates a stack from <paramref name="cells"/>.</summary>
        /// <param name="cells">Cells to stack; the sequence is copied into <see cref="Cells"/>.</param>
        /// <param name="args">Layer arguments; <c>ReverseStateOrder</c> is read from here.</param>
        /// <exception cref="WarningException">Thrown when <c>args.ReverseStateOrder</c> is true.</exception>
        public StackedRNNCells(IEnumerable cells, StackedRNNCellsArgs args) : base(args)
        {
            Cells = cells.ToList();
            _reverse_state_order = args.ReverseStateOrder;
            if (_reverse_state_order)
            {
                throw new WarningException("reverse_state_order=True in StackedRNNCells will soon " +
                    "be deprecated. Please update the code to work with the " +
                    "natural order of states if you rely on the RNN states, " +
                    "eg RNN(return_state=True).");
            }
        }

        // NOTE(review): reports false although Call reads RnnOptionalArgs.Constants — confirm intended.
        public bool SupportOptionalArgs => false;

        /// <summary>
        /// The nested state sizes of all cells. They are listed in cell order, or in reverse cell
        /// order when <see cref="_reverse_state_order"/> is set.
        /// </summary>
        public INestStructure StateSize
        {
            get
            {
                var ordered_cells = _reverse_state_order ? Cells.Reverse() : Cells;
                return new Nest(ordered_cells.Select(cell => cell.StateSize));
            }
        }

        /// <summary>The output size of the stack, which is the output size of the last cell.</summary>
        public INestStructure OutputSize => ResolveOutputDim(Cells.Last());

        /// <summary>
        /// Builds the initial state of every cell and returns them together.
        /// They are listed in reverse cell order when <see cref="_reverse_state_order"/> is set.
        /// </summary>
        public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
        {
            var cells = _reverse_state_order ? Cells.Reverse() : Cells;
            List initial_states = new List();
            foreach (var cell in cells)
            {
                initial_states.Add(cell.GetInitialState(inputs, batch_size, dtype));
            }
            return new Tensors(initial_states);
        }

        /// <summary>
        /// Runs one step through every cell in order.
        /// </summary>
        /// <param name="inputs">Input to the first cell.</param>
        /// <param name="states">Flat states of all cells. They are regrouped per cell using <see cref="StateSize"/>.</param>
        /// <param name="training">
        /// When set, forwarded to every cell. When null, each cell is applied with the default
        /// of its own <c>Apply</c>.
        /// </param>
        /// <param name="optional_args">
        /// If this is an <c>RnnOptionalArgs</c>, its <c>Constants</c> are passed to every cell.
        /// </param>
        /// <returns>The last cell's output, followed by the new states of all cells.</returns>
        protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            // Recover per-cell states.
            // NOTE(review): in the reverse-order branch the sizes are flattened before reversing,
            // so per-cell grouping of multi-part states is not preserved — confirm.
            var state_size = _reverse_state_order ? new NestList(StateSize.Flatten().Reverse()) : StateSize;
            var nested_states = Nest.PackSequenceAs(state_size, Nest.Flatten(states).ToArray());

            // The constants do not change between cells, so they are extracted once.
            Tensors? constants = (optional_args as RnnOptionalArgs)?.Constants;

            var new_nest_states = Nest.Empty;
            // Call the cells in order. Each cell's output becomes the next cell's input.
            foreach (var (cell, internal_states) in zip(Cells, nested_states))
            {
                var cell_args = new RnnOptionalArgs() { Constants = constants };
                Tensors new_states;
                // Forward an explicit training flag so that behaviour which depends on training
                // (e.g. dropout) follows the caller. When the flag is unset, leave Apply's own
                // default in place.
                (inputs, new_states) = training.HasValue
                    ? cell.Apply(inputs, internal_states, training: training, optional_args: cell_args)
                    : cell.Apply(inputs, internal_states, optional_args: cell_args);
                new_nest_states = new_nest_states.MergeWith(new_states);
            }
            return Tensors.FromNest((inputs, Nest.PackSequenceAs(state_size, Nest.Flatten(new_nest_states).ToArray())));
        }

        /// <summary>
        /// Builds every cell that is a <see cref="Layer"/> and has not been built yet.
        /// Each cell is built with the shape the previous cell produces.
        /// </summary>
        public override void build(KerasShapesWrapper input_shape)
        {
            var shape = input_shape.ToSingleShape();
            foreach (var cell in Cells)
            {
                if (cell is Layer layer && !layer.Built)
                {
                    // ignored the name scope.
                    layer.build(shape);
                    layer.Built = true;
                }
                // The next cell is built with the leading dimension of the current shape,
                // followed by this cell's output dimensions.
                var output_dim = ResolveOutputDim(cell);
                shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.Flatten()).ToArray());
            }
            this.Built = true;
        }

        /// <summary>
        /// Returns the output size of a single cell, using the first rule that applies:
        /// its <c>OutputSize</c> if that is not null; otherwise the first flattened entry of
        /// <c>StateSize</c> if the cell has multiple states; otherwise <c>StateSize</c> itself.
        /// </summary>
        private static INestStructure ResolveOutputDim(IRnnCell cell)
        {
            if (cell.OutputSize is not null)
            {
                return cell.OutputSize;
            }
            if (RnnUtils.is_multiple_state(cell.StateSize))
            {
                return new NestNode(cell.StateSize.Flatten().First());
            }
            return cell.StateSize;
        }

        /// <summary>Serialization is not supported yet; this always throws.</summary>
        /// <exception cref="NotImplementedException">Always.</exception>
        public override IKerasConfig get_config()
        {
            throw new NotImplementedException();
            // Reference Python implementation (not yet ported):
            //def get_config(self):
            //    cells = []
            //    for cell in self.cells:
            //        cells.append(generic_utils.serialize_keras_object(cell))
            //    config = {'cells': cells}
            //    base_config = super(StackedRNNCells, self).get_config()
            //    return dict(list(base_config.items()) + list(config.items()))
        }

        /// <summary>Deserialization is not supported yet; this always throws.</summary>
        /// <exception cref="NotImplementedException">Always.</exception>
        public void from_config()
        {
            throw new NotImplementedException();
            // Reference Python implementation (not yet ported):
            // @classmethod
            // def from_config(cls, config, custom_objects = None):
            //     from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
            //     cells = []
            //     for cell_config in config.pop('cells'):
            //         cells.append(
            //             deserialize_layer(cell_config, custom_objects = custom_objects))
            //     return cls(cells, **config)
        }
    }
}