diff --git a/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs index 40190315..986136f4 100644 --- a/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs +++ b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs @@ -7,21 +7,6 @@ namespace Tensorflow.Common.Types { public class GeneralizedTensorShape: Nest { - ////public TensorShapeConfig[] Shapes { get; set; } - ///// - ///// create a single-dim generalized Tensor shape. - ///// - ///// - //public GeneralizedTensorShape(int dim, int size = 1) - //{ - // var elem = new TensorShapeConfig() { Items = new long?[] { dim } }; - // Shapes = Enumerable.Repeat(elem, size).ToArray(); - // //Shapes = new TensorShapeConfig[size]; - // //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } }); - // //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } }); - // ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } }; - //} - public GeneralizedTensorShape(Shape value, string? name = null) { NodeValue = value; diff --git a/src/TensorFlowNET.Core/Common/Types/NestList.cs b/src/TensorFlowNET.Core/Common/Types/NestList.cs index e38675da..1e0d272b 100644 --- a/src/TensorFlowNET.Core/Common/Types/NestList.cs +++ b/src/TensorFlowNET.Core/Common/Types/NestList.cs @@ -15,7 +15,12 @@ namespace Tensorflow.Common.Types public int ShallowNestedCount => Values.Count; public int TotalNestedCount => Values.Count; - + + public NestList(params T[] values) + { + Values = new List(values); + } + public NestList(IEnumerable values) { Values = new List(values); diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs index 8614391a..8d6fbc97 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs @@ -10,11 +10,11 @@ namespace Tensorflow.Keras.Layers.Rnn /// /// If the derived class tends to not implement it, please return null. /// - GeneralizedTensorShape? StateSize { get; } + INestStructure? StateSize { get; } /// /// If the derived class tends to not implement it, please return null. /// - GeneralizedTensorShape? OutputSize { get; } + INestStructure? OutputSize { get; } /// /// Whether the optional RNN args are supported when appying the layer. /// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`. diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs index c339f12d..cbbf66b4 100644 --- a/src/TensorFlowNET.Core/Numpy/Shape.cs +++ b/src/TensorFlowNET.Core/Numpy/Shape.cs @@ -19,13 +19,14 @@ using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Common.Types; using Tensorflow.Keras.Saving.Common; using Tensorflow.NumPy; namespace Tensorflow { [JsonConverter(typeof(CustomizedShapeJsonConverter))] - public class Shape + public class Shape : INestStructure { public int ndim => _dims == null ? -1 : _dims.Length; long[] _dims; @@ -41,6 +42,27 @@ namespace Tensorflow } } + public NestType NestType => NestType.List; + + public int ShallowNestedCount => ndim; + /// + /// The total item count of depth 1 of the nested structure. + /// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. + /// + public int TotalNestedCount => ndim; + + public IEnumerable Flatten() => dims.Select(x => x); + + public INestStructure MapStructure(Func func) + { + return new NestList(dims.Select(x => func(x))); + } + + public Nest AsNest() + { + return new NestList(Flatten()).AsNest(); + } + #region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges public int Length => ndim; public long[] Slice(int start, int length) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index b651089a..e488c47e 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -185,8 +185,8 @@ namespace Tensorflow { throw new NotImplementedException(); } - public GeneralizedTensorShape StateSize => throw new NotImplementedException(); - public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); + public INestStructure StateSize => throw new NotImplementedException(); + public INestStructure OutputSize => throw new NotImplementedException(); public bool IsTFRnnCell => throw new NotImplementedException(); public bool SupportOptionalArgs => throw new NotImplementedException(); } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs index 1cc36d34..75feb8ea 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs @@ -18,8 +18,8 @@ namespace Tensorflow.Keras.Layers.Rnn } - public abstract GeneralizedTensorShape StateSize { get; } - public abstract GeneralizedTensorShape OutputSize { get; } + public abstract INestStructure StateSize { get; } + public abstract INestStructure OutputSize { get; } public abstract bool SupportOptionalArgs { get; } public virtual Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype) { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs index 94d98e13..17042767 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs @@ -22,13 +22,11 @@ namespace Tensorflow.Keras.Layers.Rnn IVariableV1 _recurrent_kernel; IInitializer _bias_initializer; IVariableV1 _bias; - GeneralizedTensorShape _state_size; - GeneralizedTensorShape _output_size; - public override GeneralizedTensorShape StateSize => _state_size; + INestStructure _state_size; + INestStructure _output_size; + public override INestStructure StateSize => _state_size; - public override GeneralizedTensorShape OutputSize => _output_size; - - public override bool IsTFRnnCell => true; + public override INestStructure OutputSize => _output_size; public override bool SupportOptionalArgs => false; public LSTMCell(LSTMCellArgs args) @@ -49,10 +47,8 @@ namespace Tensorflow.Keras.Layers.Rnn _args.Implementation = 1; } - _state_size = new GeneralizedTensorShape(_args.Units, 2); - _output_size = new GeneralizedTensorShape(_args.Units); - - + _state_size = new NestList(_args.Units, _args.Units); + _output_size = new NestNode(_args.Units); } public override void build(KerasShapesWrapper input_shape) @@ -229,11 +225,6 @@ namespace Tensorflow.Keras.Layers.Rnn var o = _args.RecurrentActivation.Apply(z3); return new Tensors(c, o); } - - public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null) - { - return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value); - } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index f99bc23a..0aeacc25 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -86,7 +86,7 @@ namespace Tensorflow.Keras.Layers.Rnn set { _states = value; } } - private OneOf> compute_output_shape(Shape input_shape) + private INestStructure compute_output_shape(Shape input_shape) { var batch = input_shape[0]; var time_step = input_shape[1]; @@ -96,13 +96,15 @@ namespace Tensorflow.Keras.Layers.Rnn } // state_size is a array of ints or a positive integer - var state_size = Cell.StateSize.ToSingleShape(); + var state_size = Cell.StateSize; + if(state_size?.TotalNestedCount == 1) + { + state_size = new NestList(state_size.Flatten().First()); + } - // TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor - Func _get_output_shape; - _get_output_shape = (flat_output_size) => + Func _get_output_shape = (flat_output_size) => { - var output_dim = flat_output_size.as_int_list(); + var output_dim = new Shape(flat_output_size).as_int_list(); Shape output_shape; if (_args.ReturnSequences) { @@ -125,31 +127,28 @@ namespace Tensorflow.Keras.Layers.Rnn Type type = Cell.GetType(); PropertyInfo output_size_info = type.GetProperty("output_size"); - Shape output_shape; + INestStructure output_shape; if (output_size_info != null) { - output_shape = nest.map_structure(_get_output_shape, Cell.OutputSize.ToSingleShape()); - // TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型 - output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape); + output_shape = Nest.MapStructure(_get_output_shape, Cell.OutputSize); } else { - output_shape = _get_output_shape(state_size); + output_shape = new NestNode(_get_output_shape(state_size.Flatten().First())); } if (_args.ReturnState) { - Func _get_state_shape; - _get_state_shape = (flat_state) => + Func _get_state_shape = (flat_state) => { - var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list()); + var state_shape = new int[] { (int)batch }.concat(new Shape(flat_state).as_int_list()); return new Shape(state_shape); }; - var state_shape = _get_state_shape(state_size); + var state_shape = Nest.MapStructure(_get_state_shape, state_size); - return new List { output_shape, state_shape }; + return new Nest(new[] { output_shape, state_shape } ); } else { @@ -435,7 +434,7 @@ namespace Tensorflow.Keras.Layers.Rnn tmp.add(tf.math.count_nonzero(s.Single())); } var non_zero_count = tf.add_n(tmp); - //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state); + initial_state = tf.cond(non_zero_count > 0, States, initial_state); if ((int)non_zero_count.numpy() > 0) { initial_state = States; @@ -445,16 +444,7 @@ namespace Tensorflow.Keras.Layers.Rnn { initial_state = States; } - // TODO(Wanglongzhi2001), -// initial_state = tf.nest.map_structure( -//# When the layer has a inferred dtype, use the dtype from the -//# cell. -// lambda v: tf.cast( -// v, self.compute_dtype or self.cell.compute_dtype -// ), -// initial_state, -// ) - + //initial_state = Nest.MapStructure(v => tf.cast(v, this.), initial_state); } else if (initial_state is null) { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index d318dc45..8fdc598e 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -24,11 +24,11 @@ namespace Tensorflow.Keras.Layers.Rnn IVariableV1 _kernel; IVariableV1 _recurrent_kernel; IVariableV1 _bias; - GeneralizedTensorShape _state_size; - GeneralizedTensorShape _output_size; + INestStructure _state_size; + INestStructure _output_size; - public override GeneralizedTensorShape StateSize => _state_size; - public override GeneralizedTensorShape OutputSize => _output_size; + public override INestStructure StateSize => _state_size; + public override INestStructure OutputSize => _output_size; public override bool SupportOptionalArgs => false; public SimpleRNNCell(SimpleRNNCellArgs args) : base(args) @@ -41,8 +41,8 @@ namespace Tensorflow.Keras.Layers.Rnn } this._args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout)); this._args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout)); - _state_size = new GeneralizedTensorShape(args.Units); - _output_size = new GeneralizedTensorShape(args.Units); + _state_size = new NestNode(args.Units); + _output_size = new NestNode(args.Units); } public override void build(KerasShapesWrapper input_shape) diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs index fb74d6d2..3e7b227c 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs @@ -1,10 +1,8 @@ using System; -using System.Collections.Generic; using System.ComponentModel; using System.Linq; using Tensorflow.Common.Extensions; using Tensorflow.Common.Types; -using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; @@ -38,24 +36,24 @@ namespace Tensorflow.Keras.Layers.Rnn public bool SupportOptionalArgs => false; - public GeneralizedTensorShape StateSize + public INestStructure StateSize { get { if (_reverse_state_order) { var state_sizes = Cells.Reverse().Select(cell => cell.StateSize); - return new GeneralizedTensorShape(new Nest(state_sizes.Select(s => new Nest(s)))); + return new Nest(state_sizes); } else { var state_sizes = Cells.Select(cell => cell.StateSize); - return new GeneralizedTensorShape(new Nest(state_sizes.Select(s => new Nest(s)))); + return new Nest(state_sizes); } } } - public GeneralizedTensorShape OutputSize + public INestStructure OutputSize { get { @@ -66,7 +64,7 @@ namespace Tensorflow.Keras.Layers.Rnn } else if (RnnUtils.is_multiple_state(lastCell.StateSize)) { - return lastCell.StateSize.First(); + return new NestNode(lastCell.StateSize.Flatten().First()); } else { @@ -89,7 +87,7 @@ namespace Tensorflow.Keras.Layers.Rnn protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) { // Recover per-cell states. - var state_size = _reverse_state_order ? new GeneralizedTensorShape(StateSize.Reverse()) : StateSize; + var state_size = _reverse_state_order ? new NestList(StateSize.Flatten().Reverse()) : StateSize; var nested_states = Nest.PackSequenceAs(state_size, Nest.Flatten(states).ToArray()); var new_nest_states = Nest.Empty; @@ -118,20 +116,20 @@ namespace Tensorflow.Keras.Layers.Rnn layer.build(shape); layer.Built = true; } - GeneralizedTensorShape output_dim; + INestStructure output_dim; if(cell.OutputSize is not null) { output_dim = cell.OutputSize; } else if (RnnUtils.is_multiple_state(cell.StateSize)) { - output_dim = cell.StateSize.First(); + output_dim = new NestNode(cell.StateSize.Flatten().First()); } else { output_dim = cell.StateSize; } - shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.ToSingleShape().dims).ToArray()); + shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.Flatten()).ToArray()); } this.Built = true; } diff --git a/src/TensorFlowNET.Keras/Utils/RnnUtils.cs b/src/TensorFlowNET.Keras/Utils/RnnUtils.cs index 7ff3f9fb..e8700c1f 100644 --- a/src/TensorFlowNET.Keras/Utils/RnnUtils.cs +++ b/src/TensorFlowNET.Keras/Utils/RnnUtils.cs @@ -10,12 +10,11 @@ namespace Tensorflow.Keras.Utils { internal static class RnnUtils { - internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) + internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, INestStructure state_size, TF_DataType dtype) { - Func create_zeros; - create_zeros = (GeneralizedTensorShape unnested_state_size) => + Func create_zeros = (unnested_state_size) => { - var flat_dims = unnested_state_size.ToSingleShape().dims; + var flat_dims = new Shape(unnested_state_size).dims; var init_state_size = new Tensor[] { batch_size_tensor }. Concat(flat_dims.Select(x => tf.constant(x, dtypes.int32))).ToArray(); return array_ops.zeros(init_state_size, dtype: dtype); @@ -24,11 +23,11 @@ namespace Tensorflow.Keras.Utils // TODO(Rinne): map structure with nested tensors. if(state_size.TotalNestedCount > 1) { - return new Tensors(state_size.Flatten().Select(s => create_zeros(new GeneralizedTensorShape(s))).ToArray()); + return new Tensors(state_size.Flatten().Select(s => create_zeros(s)).ToArray()); } else { - return create_zeros(state_size); + return create_zeros(state_size.Flatten().First()); } } @@ -96,7 +95,7 @@ namespace Tensorflow.Keras.Utils /// /// /// - public static bool is_multiple_state(GeneralizedTensorShape state_size) + public static bool is_multiple_state(INestStructure state_size) { return state_size.TotalNestedCount > 1; }