Browse Source

fix: error after merging LSTM support.

tags/v0.110.0-LSTM-Model
Yaohui Liu 2 years ago
parent
commit
6b30902ee8
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
11 changed files with 79 additions and 89 deletions
  1. +0
    -15
      src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
  2. +6
    -1
      src/TensorFlowNET.Core/Common/Types/NestList.cs
  3. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs
  4. +23
    -1
      src/TensorFlowNET.Core/Numpy/Shape.cs
  5. +2
    -2
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  6. +2
    -2
      src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
  7. +6
    -15
      src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs
  8. +17
    -27
      src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
  9. +6
    -6
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
  10. +9
    -11
      src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
  11. +6
    -7
      src/TensorFlowNET.Keras/Utils/RnnUtils.cs

+ 0
- 15
src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs View File

@@ -7,21 +7,6 @@ namespace Tensorflow.Common.Types
{
public class GeneralizedTensorShape: Nest<Shape>
{
////public TensorShapeConfig[] Shapes { get; set; }
///// <summary>
///// create a single-dim generalized Tensor shape.
///// </summary>
///// <param name="dim"></param>
//public GeneralizedTensorShape(int dim, int size = 1)
//{
// var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
// Shapes = Enumerable.Repeat(elem, size).ToArray();
// //Shapes = new TensorShapeConfig[size];
// //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
// //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
// ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
//}

public GeneralizedTensorShape(Shape value, string? name = null)
{
NodeValue = value;


+ 6
- 1
src/TensorFlowNET.Core/Common/Types/NestList.cs View File

@@ -15,7 +15,12 @@ namespace Tensorflow.Common.Types
public int ShallowNestedCount => Values.Count;

public int TotalNestedCount => Values.Count;

public NestList(params T[] values)
{
Values = new List<T>(values);
}

public NestList(IEnumerable<T> values)
{
Values = new List<T>(values);


+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs View File

@@ -10,11 +10,11 @@ namespace Tensorflow.Keras.Layers.Rnn
/// <summary>
/// If the derived class tends to not implement it, please return null.
/// </summary>
GeneralizedTensorShape? StateSize { get; }
INestStructure<long>? StateSize { get; }
/// <summary>
/// If the derived class tends to not implement it, please return null.
/// </summary>
GeneralizedTensorShape? OutputSize { get; }
INestStructure<long>? OutputSize { get; }
/// <summary>
/// Whether the optional RNN args are supported when appying the layer.
/// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`.


+ 23
- 1
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -19,13 +19,14 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Saving.Common;
using Tensorflow.NumPy;

namespace Tensorflow
{
[JsonConverter(typeof(CustomizedShapeJsonConverter))]
public class Shape
public class Shape : INestStructure<long>
{
public int ndim => _dims == null ? -1 : _dims.Length;
long[] _dims;
@@ -41,6 +42,27 @@ namespace Tensorflow
}
}

public NestType NestType => NestType.List;

public int ShallowNestedCount => ndim;
/// <summary>
/// The total item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
/// </summary>
public int TotalNestedCount => ndim;

public IEnumerable<long> Flatten() => dims.Select(x => x);

public INestStructure<TOut> MapStructure<TOut>(Func<long, TOut> func)
{
return new NestList<TOut>(dims.Select(x => func(x)));
}

public Nest<long> AsNest()
{
return new NestList<long>(Flatten()).AsNest();
}

#region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges
public int Length => ndim;
public long[] Slice(int start, int length)


+ 2
- 2
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -185,8 +185,8 @@ namespace Tensorflow
{
throw new NotImplementedException();
}
public GeneralizedTensorShape StateSize => throw new NotImplementedException();
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
public INestStructure<long> StateSize => throw new NotImplementedException();
public INestStructure<long> OutputSize => throw new NotImplementedException();
public bool IsTFRnnCell => throw new NotImplementedException();
public bool SupportOptionalArgs => throw new NotImplementedException();
}


+ 2
- 2
src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs View File

@@ -18,8 +18,8 @@ namespace Tensorflow.Keras.Layers.Rnn

}

public abstract GeneralizedTensorShape StateSize { get; }
public abstract GeneralizedTensorShape OutputSize { get; }
public abstract INestStructure<long> StateSize { get; }
public abstract INestStructure<long> OutputSize { get; }
public abstract bool SupportOptionalArgs { get; }
public virtual Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype)
{


+ 6
- 15
src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs View File

@@ -22,13 +22,11 @@ namespace Tensorflow.Keras.Layers.Rnn
IVariableV1 _recurrent_kernel;
IInitializer _bias_initializer;
IVariableV1 _bias;
GeneralizedTensorShape _state_size;
GeneralizedTensorShape _output_size;
public override GeneralizedTensorShape StateSize => _state_size;
INestStructure<long> _state_size;
INestStructure<long> _output_size;
public override INestStructure<long> StateSize => _state_size;

public override GeneralizedTensorShape OutputSize => _output_size;

public override bool IsTFRnnCell => true;
public override INestStructure<long> OutputSize => _output_size;

public override bool SupportOptionalArgs => false;
public LSTMCell(LSTMCellArgs args)
@@ -49,10 +47,8 @@ namespace Tensorflow.Keras.Layers.Rnn
_args.Implementation = 1;
}

_state_size = new GeneralizedTensorShape(_args.Units, 2);
_output_size = new GeneralizedTensorShape(_args.Units);


_state_size = new NestList<long>(_args.Units, _args.Units);
_output_size = new NestNode<long>(_args.Units);
}

public override void build(KerasShapesWrapper input_shape)
@@ -229,11 +225,6 @@ namespace Tensorflow.Keras.Layers.Rnn
var o = _args.RecurrentActivation.Apply(z3);
return new Tensors(c, o);
}

public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
{
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
}
}



+ 17
- 27
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs View File

@@ -86,7 +86,7 @@ namespace Tensorflow.Keras.Layers.Rnn
set { _states = value; }
}

private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
private INestStructure<Shape> compute_output_shape(Shape input_shape)
{
var batch = input_shape[0];
var time_step = input_shape[1];
@@ -96,13 +96,15 @@ namespace Tensorflow.Keras.Layers.Rnn
}

// state_size is a array of ints or a positive integer
var state_size = Cell.StateSize.ToSingleShape();
var state_size = Cell.StateSize;
if(state_size?.TotalNestedCount == 1)
{
state_size = new NestList<long>(state_size.Flatten().First());
}

// TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor
Func<Shape, Shape> _get_output_shape;
_get_output_shape = (flat_output_size) =>
Func<long, Shape> _get_output_shape = (flat_output_size) =>
{
var output_dim = flat_output_size.as_int_list();
var output_dim = new Shape(flat_output_size).as_int_list();
Shape output_shape;
if (_args.ReturnSequences)
{
@@ -125,31 +127,28 @@ namespace Tensorflow.Keras.Layers.Rnn

Type type = Cell.GetType();
PropertyInfo output_size_info = type.GetProperty("output_size");
Shape output_shape;
INestStructure<Shape> output_shape;
if (output_size_info != null)
{
output_shape = nest.map_structure(_get_output_shape, Cell.OutputSize.ToSingleShape());
// TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型
output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape);
output_shape = Nest.MapStructure(_get_output_shape, Cell.OutputSize);
}
else
{
output_shape = _get_output_shape(state_size);
output_shape = new NestNode<Shape>(_get_output_shape(state_size.Flatten().First()));
}

if (_args.ReturnState)
{
Func<Shape, Shape> _get_state_shape;
_get_state_shape = (flat_state) =>
Func<long, Shape> _get_state_shape = (flat_state) =>
{
var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list());
var state_shape = new int[] { (int)batch }.concat(new Shape(flat_state).as_int_list());
return new Shape(state_shape);
};


var state_shape = _get_state_shape(state_size);
var state_shape = Nest.MapStructure(_get_state_shape, state_size);

return new List<Shape> { output_shape, state_shape };
return new Nest<Shape>(new[] { output_shape, state_shape } );
}
else
{
@@ -435,7 +434,7 @@ namespace Tensorflow.Keras.Layers.Rnn
tmp.add(tf.math.count_nonzero(s.Single()));
}
var non_zero_count = tf.add_n(tmp);
//initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
initial_state = tf.cond(non_zero_count > 0, States, initial_state);
if ((int)non_zero_count.numpy() > 0)
{
initial_state = States;
@@ -445,16 +444,7 @@ namespace Tensorflow.Keras.Layers.Rnn
{
initial_state = States;
}
// TODO(Wanglongzhi2001),
// initial_state = tf.nest.map_structure(
//# When the layer has a inferred dtype, use the dtype from the
//# cell.
// lambda v: tf.cast(
// v, self.compute_dtype or self.cell.compute_dtype
// ),
// initial_state,
// )

//initial_state = Nest.MapStructure(v => tf.cast(v, this.), initial_state);
}
else if (initial_state is null)
{


+ 6
- 6
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs View File

@@ -24,11 +24,11 @@ namespace Tensorflow.Keras.Layers.Rnn
IVariableV1 _kernel;
IVariableV1 _recurrent_kernel;
IVariableV1 _bias;
GeneralizedTensorShape _state_size;
GeneralizedTensorShape _output_size;
INestStructure<long> _state_size;
INestStructure<long> _output_size;

public override GeneralizedTensorShape StateSize => _state_size;
public override GeneralizedTensorShape OutputSize => _output_size;
public override INestStructure<long> StateSize => _state_size;
public override INestStructure<long> OutputSize => _output_size;
public override bool SupportOptionalArgs => false;

public SimpleRNNCell(SimpleRNNCellArgs args) : base(args)
@@ -41,8 +41,8 @@ namespace Tensorflow.Keras.Layers.Rnn
}
this._args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout));
this._args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout));
_state_size = new GeneralizedTensorShape(args.Units);
_output_size = new GeneralizedTensorShape(args.Units);
_state_size = new NestNode<long>(args.Units);
_output_size = new NestNode<long>(args.Units);
}

public override void build(KerasShapesWrapper input_shape)


+ 9
- 11
src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs View File

@@ -1,10 +1,8 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
@@ -38,24 +36,24 @@ namespace Tensorflow.Keras.Layers.Rnn

public bool SupportOptionalArgs => false;

public GeneralizedTensorShape StateSize
public INestStructure<long> StateSize
{
get
{
if (_reverse_state_order)
{
var state_sizes = Cells.Reverse().Select(cell => cell.StateSize);
return new GeneralizedTensorShape(new Nest<Shape>(state_sizes.Select(s => new Nest<Shape>(s))));
return new Nest<long>(state_sizes);
}
else
{
var state_sizes = Cells.Select(cell => cell.StateSize);
return new GeneralizedTensorShape(new Nest<Shape>(state_sizes.Select(s => new Nest<Shape>(s))));
return new Nest<long>(state_sizes);
}
}
}

public GeneralizedTensorShape OutputSize
public INestStructure<long> OutputSize
{
get
{
@@ -66,7 +64,7 @@ namespace Tensorflow.Keras.Layers.Rnn
}
else if (RnnUtils.is_multiple_state(lastCell.StateSize))
{
return lastCell.StateSize.First();
return new NestNode<long>(lastCell.StateSize.Flatten().First());
}
else
{
@@ -89,7 +87,7 @@ namespace Tensorflow.Keras.Layers.Rnn
protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
{
// Recover per-cell states.
var state_size = _reverse_state_order ? new GeneralizedTensorShape(StateSize.Reverse()) : StateSize;
var state_size = _reverse_state_order ? new NestList<long>(StateSize.Flatten().Reverse()) : StateSize;
var nested_states = Nest.PackSequenceAs(state_size, Nest.Flatten(states).ToArray());

var new_nest_states = Nest<Tensor>.Empty;
@@ -118,20 +116,20 @@ namespace Tensorflow.Keras.Layers.Rnn
layer.build(shape);
layer.Built = true;
}
GeneralizedTensorShape output_dim;
INestStructure<long> output_dim;
if(cell.OutputSize is not null)
{
output_dim = cell.OutputSize;
}
else if (RnnUtils.is_multiple_state(cell.StateSize))
{
output_dim = cell.StateSize.First();
output_dim = new NestNode<long>(cell.StateSize.Flatten().First());
}
else
{
output_dim = cell.StateSize;
}
shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.ToSingleShape().dims).ToArray());
shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.Flatten()).ToArray());
}
this.Built = true;
}


+ 6
- 7
src/TensorFlowNET.Keras/Utils/RnnUtils.cs View File

@@ -10,12 +10,11 @@ namespace Tensorflow.Keras.Utils
{
internal static class RnnUtils
{
internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype)
internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, INestStructure<long> state_size, TF_DataType dtype)
{
Func<GeneralizedTensorShape, Tensor> create_zeros;
create_zeros = (GeneralizedTensorShape unnested_state_size) =>
Func<long, Tensor> create_zeros = (unnested_state_size) =>
{
var flat_dims = unnested_state_size.ToSingleShape().dims;
var flat_dims = new Shape(unnested_state_size).dims;
var init_state_size = new Tensor[] { batch_size_tensor }.
Concat(flat_dims.Select(x => tf.constant(x, dtypes.int32))).ToArray();
return array_ops.zeros(init_state_size, dtype: dtype);
@@ -24,11 +23,11 @@ namespace Tensorflow.Keras.Utils
// TODO(Rinne): map structure with nested tensors.
if(state_size.TotalNestedCount > 1)
{
return new Tensors(state_size.Flatten().Select(s => create_zeros(new GeneralizedTensorShape(s))).ToArray());
return new Tensors(state_size.Flatten().Select(s => create_zeros(s)).ToArray());
}
else
{
return create_zeros(state_size);
return create_zeros(state_size.Flatten().First());
}

}
@@ -96,7 +95,7 @@ namespace Tensorflow.Keras.Utils
/// </summary>
/// <param name="state_size"></param>
/// <returns></returns>
public static bool is_multiple_state(GeneralizedTensorShape state_size)
public static bool is_multiple_state(INestStructure<long> state_size)
{
return state_size.TotalNestedCount > 1;
}


Loading…
Cancel
Save