| @@ -7,21 +7,6 @@ namespace Tensorflow.Common.Types | |||
| { | |||
| public class GeneralizedTensorShape: Nest<Shape> | |||
| { | |||
| ////public TensorShapeConfig[] Shapes { get; set; } | |||
| ///// <summary> | |||
| ///// create a single-dim generalized Tensor shape. | |||
| ///// </summary> | |||
| ///// <param name="dim"></param> | |||
| //public GeneralizedTensorShape(int dim, int size = 1) | |||
| //{ | |||
| // var elem = new TensorShapeConfig() { Items = new long?[] { dim } }; | |||
| // Shapes = Enumerable.Repeat(elem, size).ToArray(); | |||
| // //Shapes = new TensorShapeConfig[size]; | |||
| // //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } }); | |||
| // //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } }); | |||
| // ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } }; | |||
| //} | |||
| public GeneralizedTensorShape(Shape value, string? name = null) | |||
| { | |||
| NodeValue = value; | |||
| @@ -15,7 +15,12 @@ namespace Tensorflow.Common.Types | |||
| public int ShallowNestedCount => Values.Count; | |||
| public int TotalNestedCount => Values.Count; | |||
| public NestList(params T[] values) | |||
| { | |||
| Values = new List<T>(values); | |||
| } | |||
| public NestList(IEnumerable<T> values) | |||
| { | |||
| Values = new List<T>(values); | |||
| @@ -10,11 +10,11 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| /// <summary> | |||
| /// If the derived class tends to not implement it, please return null. | |||
| /// </summary> | |||
| GeneralizedTensorShape? StateSize { get; } | |||
| INestStructure<long>? StateSize { get; } | |||
| /// <summary> | |||
| /// If the derived class tends to not implement it, please return null. | |||
| /// </summary> | |||
| GeneralizedTensorShape? OutputSize { get; } | |||
| INestStructure<long>? OutputSize { get; } | |||
| /// <summary> | |||
| /// Whether the optional RNN args are supported when appying the layer. | |||
| /// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`. | |||
| @@ -19,13 +19,14 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.Saving.Common; | |||
| using Tensorflow.NumPy; | |||
| namespace Tensorflow | |||
| { | |||
| [JsonConverter(typeof(CustomizedShapeJsonConverter))] | |||
| public class Shape | |||
| public class Shape : INestStructure<long> | |||
| { | |||
| public int ndim => _dims == null ? -1 : _dims.Length; | |||
| long[] _dims; | |||
| @@ -41,6 +42,27 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public NestType NestType => NestType.List; | |||
| public int ShallowNestedCount => ndim; | |||
| /// <summary> | |||
| /// The total item count of depth 1 of the nested structure. | |||
| /// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. | |||
| /// </summary> | |||
| public int TotalNestedCount => ndim; | |||
| public IEnumerable<long> Flatten() => dims.Select(x => x); | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<long, TOut> func) | |||
| { | |||
| return new NestList<TOut>(dims.Select(x => func(x))); | |||
| } | |||
| public Nest<long> AsNest() | |||
| { | |||
| return new NestList<long>(Flatten()).AsNest(); | |||
| } | |||
| #region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges | |||
| public int Length => ndim; | |||
| public long[] Slice(int start, int length) | |||
| @@ -185,8 +185,8 @@ namespace Tensorflow | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public GeneralizedTensorShape StateSize => throw new NotImplementedException(); | |||
| public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); | |||
| public INestStructure<long> StateSize => throw new NotImplementedException(); | |||
| public INestStructure<long> OutputSize => throw new NotImplementedException(); | |||
| public bool IsTFRnnCell => throw new NotImplementedException(); | |||
| public bool SupportOptionalArgs => throw new NotImplementedException(); | |||
| } | |||
| @@ -18,8 +18,8 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| } | |||
| public abstract GeneralizedTensorShape StateSize { get; } | |||
| public abstract GeneralizedTensorShape OutputSize { get; } | |||
| public abstract INestStructure<long> StateSize { get; } | |||
| public abstract INestStructure<long> OutputSize { get; } | |||
| public abstract bool SupportOptionalArgs { get; } | |||
| public virtual Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype) | |||
| { | |||
| @@ -22,13 +22,11 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| IVariableV1 _recurrent_kernel; | |||
| IInitializer _bias_initializer; | |||
| IVariableV1 _bias; | |||
| GeneralizedTensorShape _state_size; | |||
| GeneralizedTensorShape _output_size; | |||
| public override GeneralizedTensorShape StateSize => _state_size; | |||
| INestStructure<long> _state_size; | |||
| INestStructure<long> _output_size; | |||
| public override INestStructure<long> StateSize => _state_size; | |||
| public override GeneralizedTensorShape OutputSize => _output_size; | |||
| public override bool IsTFRnnCell => true; | |||
| public override INestStructure<long> OutputSize => _output_size; | |||
| public override bool SupportOptionalArgs => false; | |||
| public LSTMCell(LSTMCellArgs args) | |||
| @@ -49,10 +47,8 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| _args.Implementation = 1; | |||
| } | |||
| _state_size = new GeneralizedTensorShape(_args.Units, 2); | |||
| _output_size = new GeneralizedTensorShape(_args.Units); | |||
| _state_size = new NestList<long>(_args.Units, _args.Units); | |||
| _output_size = new NestNode<long>(_args.Units); | |||
| } | |||
| public override void build(KerasShapesWrapper input_shape) | |||
| @@ -229,11 +225,6 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| var o = _args.RecurrentActivation.Apply(z3); | |||
| return new Tensors(c, o); | |||
| } | |||
| public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null) | |||
| { | |||
| return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value); | |||
| } | |||
| } | |||
| @@ -86,7 +86,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| set { _states = value; } | |||
| } | |||
| private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape) | |||
| private INestStructure<Shape> compute_output_shape(Shape input_shape) | |||
| { | |||
| var batch = input_shape[0]; | |||
| var time_step = input_shape[1]; | |||
| @@ -96,13 +96,15 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| } | |||
| // state_size is a array of ints or a positive integer | |||
| var state_size = Cell.StateSize.ToSingleShape(); | |||
| var state_size = Cell.StateSize; | |||
| if(state_size?.TotalNestedCount == 1) | |||
| { | |||
| state_size = new NestList<long>(state_size.Flatten().First()); | |||
| } | |||
| // TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor | |||
| Func<Shape, Shape> _get_output_shape; | |||
| _get_output_shape = (flat_output_size) => | |||
| Func<long, Shape> _get_output_shape = (flat_output_size) => | |||
| { | |||
| var output_dim = flat_output_size.as_int_list(); | |||
| var output_dim = new Shape(flat_output_size).as_int_list(); | |||
| Shape output_shape; | |||
| if (_args.ReturnSequences) | |||
| { | |||
| @@ -125,31 +127,28 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| Type type = Cell.GetType(); | |||
| PropertyInfo output_size_info = type.GetProperty("output_size"); | |||
| Shape output_shape; | |||
| INestStructure<Shape> output_shape; | |||
| if (output_size_info != null) | |||
| { | |||
| output_shape = nest.map_structure(_get_output_shape, Cell.OutputSize.ToSingleShape()); | |||
| // TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型 | |||
| output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape); | |||
| output_shape = Nest.MapStructure(_get_output_shape, Cell.OutputSize); | |||
| } | |||
| else | |||
| { | |||
| output_shape = _get_output_shape(state_size); | |||
| output_shape = new NestNode<Shape>(_get_output_shape(state_size.Flatten().First())); | |||
| } | |||
| if (_args.ReturnState) | |||
| { | |||
| Func<Shape, Shape> _get_state_shape; | |||
| _get_state_shape = (flat_state) => | |||
| Func<long, Shape> _get_state_shape = (flat_state) => | |||
| { | |||
| var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list()); | |||
| var state_shape = new int[] { (int)batch }.concat(new Shape(flat_state).as_int_list()); | |||
| return new Shape(state_shape); | |||
| }; | |||
| var state_shape = _get_state_shape(state_size); | |||
| var state_shape = Nest.MapStructure(_get_state_shape, state_size); | |||
| return new List<Shape> { output_shape, state_shape }; | |||
| return new Nest<Shape>(new[] { output_shape, state_shape } ); | |||
| } | |||
| else | |||
| { | |||
| @@ -435,7 +434,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| tmp.add(tf.math.count_nonzero(s.Single())); | |||
| } | |||
| var non_zero_count = tf.add_n(tmp); | |||
| //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state); | |||
| initial_state = tf.cond(non_zero_count > 0, States, initial_state); | |||
| if ((int)non_zero_count.numpy() > 0) | |||
| { | |||
| initial_state = States; | |||
| @@ -445,16 +444,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| initial_state = States; | |||
| } | |||
| // TODO(Wanglongzhi2001), | |||
| // initial_state = tf.nest.map_structure( | |||
| //# When the layer has a inferred dtype, use the dtype from the | |||
| //# cell. | |||
| // lambda v: tf.cast( | |||
| // v, self.compute_dtype or self.cell.compute_dtype | |||
| // ), | |||
| // initial_state, | |||
| // ) | |||
| //initial_state = Nest.MapStructure(v => tf.cast(v, this.), initial_state); | |||
| } | |||
| else if (initial_state is null) | |||
| { | |||
| @@ -24,11 +24,11 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| IVariableV1 _kernel; | |||
| IVariableV1 _recurrent_kernel; | |||
| IVariableV1 _bias; | |||
| GeneralizedTensorShape _state_size; | |||
| GeneralizedTensorShape _output_size; | |||
| INestStructure<long> _state_size; | |||
| INestStructure<long> _output_size; | |||
| public override GeneralizedTensorShape StateSize => _state_size; | |||
| public override GeneralizedTensorShape OutputSize => _output_size; | |||
| public override INestStructure<long> StateSize => _state_size; | |||
| public override INestStructure<long> OutputSize => _output_size; | |||
| public override bool SupportOptionalArgs => false; | |||
| public SimpleRNNCell(SimpleRNNCellArgs args) : base(args) | |||
| @@ -41,8 +41,8 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| } | |||
| this._args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout)); | |||
| this._args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout)); | |||
| _state_size = new GeneralizedTensorShape(args.Units); | |||
| _output_size = new GeneralizedTensorShape(args.Units); | |||
| _state_size = new NestNode<long>(args.Units); | |||
| _output_size = new NestNode<long>(args.Units); | |||
| } | |||
| public override void build(KerasShapesWrapper input_shape) | |||
| @@ -1,10 +1,8 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.ComponentModel; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Extensions; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| @@ -38,24 +36,24 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| public bool SupportOptionalArgs => false; | |||
| public GeneralizedTensorShape StateSize | |||
| public INestStructure<long> StateSize | |||
| { | |||
| get | |||
| { | |||
| if (_reverse_state_order) | |||
| { | |||
| var state_sizes = Cells.Reverse().Select(cell => cell.StateSize); | |||
| return new GeneralizedTensorShape(new Nest<Shape>(state_sizes.Select(s => new Nest<Shape>(s)))); | |||
| return new Nest<long>(state_sizes); | |||
| } | |||
| else | |||
| { | |||
| var state_sizes = Cells.Select(cell => cell.StateSize); | |||
| return new GeneralizedTensorShape(new Nest<Shape>(state_sizes.Select(s => new Nest<Shape>(s)))); | |||
| return new Nest<long>(state_sizes); | |||
| } | |||
| } | |||
| } | |||
| public GeneralizedTensorShape OutputSize | |||
| public INestStructure<long> OutputSize | |||
| { | |||
| get | |||
| { | |||
| @@ -66,7 +64,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| } | |||
| else if (RnnUtils.is_multiple_state(lastCell.StateSize)) | |||
| { | |||
| return lastCell.StateSize.First(); | |||
| return new NestNode<long>(lastCell.StateSize.Flatten().First()); | |||
| } | |||
| else | |||
| { | |||
| @@ -89,7 +87,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| // Recover per-cell states. | |||
| var state_size = _reverse_state_order ? new GeneralizedTensorShape(StateSize.Reverse()) : StateSize; | |||
| var state_size = _reverse_state_order ? new NestList<long>(StateSize.Flatten().Reverse()) : StateSize; | |||
| var nested_states = Nest.PackSequenceAs(state_size, Nest.Flatten(states).ToArray()); | |||
| var new_nest_states = Nest<Tensor>.Empty; | |||
| @@ -118,20 +116,20 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| layer.build(shape); | |||
| layer.Built = true; | |||
| } | |||
| GeneralizedTensorShape output_dim; | |||
| INestStructure<long> output_dim; | |||
| if(cell.OutputSize is not null) | |||
| { | |||
| output_dim = cell.OutputSize; | |||
| } | |||
| else if (RnnUtils.is_multiple_state(cell.StateSize)) | |||
| { | |||
| output_dim = cell.StateSize.First(); | |||
| output_dim = new NestNode<long>(cell.StateSize.Flatten().First()); | |||
| } | |||
| else | |||
| { | |||
| output_dim = cell.StateSize; | |||
| } | |||
| shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.ToSingleShape().dims).ToArray()); | |||
| shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.Flatten()).ToArray()); | |||
| } | |||
| this.Built = true; | |||
| } | |||
| @@ -10,12 +10,11 @@ namespace Tensorflow.Keras.Utils | |||
| { | |||
| internal static class RnnUtils | |||
| { | |||
| internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) | |||
| internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, INestStructure<long> state_size, TF_DataType dtype) | |||
| { | |||
| Func<GeneralizedTensorShape, Tensor> create_zeros; | |||
| create_zeros = (GeneralizedTensorShape unnested_state_size) => | |||
| Func<long, Tensor> create_zeros = (unnested_state_size) => | |||
| { | |||
| var flat_dims = unnested_state_size.ToSingleShape().dims; | |||
| var flat_dims = new Shape(unnested_state_size).dims; | |||
| var init_state_size = new Tensor[] { batch_size_tensor }. | |||
| Concat(flat_dims.Select(x => tf.constant(x, dtypes.int32))).ToArray(); | |||
| return array_ops.zeros(init_state_size, dtype: dtype); | |||
| @@ -24,11 +23,11 @@ namespace Tensorflow.Keras.Utils | |||
| // TODO(Rinne): map structure with nested tensors. | |||
| if(state_size.TotalNestedCount > 1) | |||
| { | |||
| return new Tensors(state_size.Flatten().Select(s => create_zeros(new GeneralizedTensorShape(s))).ToArray()); | |||
| return new Tensors(state_size.Flatten().Select(s => create_zeros(s)).ToArray()); | |||
| } | |||
| else | |||
| { | |||
| return create_zeros(state_size); | |||
| return create_zeros(state_size.Flatten().First()); | |||
| } | |||
| } | |||
| @@ -96,7 +95,7 @@ namespace Tensorflow.Keras.Utils | |||
| /// </summary> | |||
| /// <param name="state_size"></param> | |||
| /// <returns></returns> | |||
| public static bool is_multiple_state(GeneralizedTensorShape state_size) | |||
| public static bool is_multiple_state(INestStructure<long> state_size) | |||
| { | |||
| return state_size.TotalNestedCount > 1; | |||
| } | |||