| @@ -1,7 +1,35 @@ | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| using Newtonsoft.Json; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| { | |||
| // TODO: complete the implementation | |||
| public class LSTMCellArgs : LayerArgs | |||
| public class LSTMCellArgs : AutoSerializeLayerArgs | |||
| { | |||
| [JsonProperty("units")] | |||
| public int Units { get; set; } | |||
// TODO(Rinne): Activation lacks a default value here; merging keras
// into tf.net could resolve it.
| [JsonProperty("activation")] | |||
| public Activation Activation { get; set; } | |||
| [JsonProperty("recurrent_activation")] | |||
| public Activation RecurrentActivation { get; set; } | |||
| [JsonProperty("use_bias")] | |||
| public bool UseBias { get; set; } = true; | |||
| [JsonProperty("dropout")] | |||
| public float Dropout { get; set; } = .0f; | |||
| [JsonProperty("recurrent_dropout")] | |||
| public float RecurrentDropout { get; set; } = .0f; | |||
| [JsonProperty("kernel_initializer")] | |||
| public IInitializer KernelInitializer { get; set; } | |||
| [JsonProperty("recurrent_initializer")] | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| [JsonProperty("bias_initializer")] | |||
| public IInitializer BiasInitializer { get; set; } | |||
| [JsonProperty("unit_forget_bias")] | |||
| public bool UnitForgetBias { get; set; } = true; | |||
| [JsonProperty("implementation")] | |||
| public int Implementation { get; set; } = 2; | |||
| } | |||
| } | |||
| @@ -1,7 +1,4 @@ | |||
| using Newtonsoft.Json; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| { | |||
| @@ -25,5 +22,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| [JsonProperty("bias_initializer")] | |||
| public IInitializer BiasInitializer { get; set; } | |||
| } | |||
| } | |||
| @@ -160,6 +160,18 @@ namespace Tensorflow.Keras.Layers | |||
| public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false); | |||
| public ILayer LeakyReLU(float alpha = 0.3f); | |||
public IRnnCell LSTMCell(int units,
string activation = "tanh",
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
bool unit_forget_bias = true,
float dropout = 0f,
float recurrent_dropout = 0f,
int implementation = 2);
public ILayer LSTM(int units,
Activation activation = null,
Activation recurrent_activation = null,
| @@ -58,8 +58,7 @@ public class Orthogonal : IInitializer | |||
| if (num_rows < num_cols) | |||
| { | |||
| // q = tf.linalg.matrix_transpose(q); | |||
| throw new NotImplementedException(""); | |||
| q = array_ops.matrix_transpose(q); | |||
| } | |||
| return _gain * tf.reshape(q, shape); | |||
| @@ -971,6 +971,49 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
/// <summary>
/// Transposes the last two dimensions of tensor `a`.
/// For example:
/// <code>
/// x = tf.constant([[1, 2, 3], [4, 5, 6]])
/// tf.matrix_transpose(x) # [[1, 4],
///                         # [2, 5],
///                         # [3, 6]]
/// </code>
/// For a matrix with two batch dimensions, e.g. `x.shape` = [1, 2, 3, 4],
/// `tf.linalg.matrix_transpose(x)` has shape [1, 2, 4, 3].
/// </summary>
/// <param name="a">A tensor with rank >= 2.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <param name="conjugate">If true, also conjugate complex elements.</param>
/// <returns>The transposed (and possibly conjugated) tensor.</returns>
/// <exception cref="ValueError">Thrown when `a` has rank less than 2.</exception>
public static Tensor matrix_transpose(Tensor a, string name = "matrix_transpose", bool conjugate = false)
{
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
var a_shape = a.shape;
var ndims = a.shape.ndim;
Axis perm;
if (ndims != 0)
{
if (ndims < 2)
{
throw new ValueError("Argument `a` should be a (batch) matrix with rank " +
$">= 2. Received `a` = {a} with shape: {a_shape}");
}
perm = new Axis(Enumerable.Range(0, ndims - 2).Concat(new int[] { ndims - 1, ndims - 2 }).ToArray());
}
else
{
var a_rank = a.rank;
perm = new Axis(Enumerable.Range(0, a_rank - 2).Concat(new int[] { a_rank - 1, a_rank - 2 }).ToArray());
}
return transpose(a, perm: perm, conjugate: conjugate);
});
}
public static Tensor[] split(Tensor value, Tensor size_splits, int axis, int num = -1,
string name = "split")
{
| @@ -702,6 +702,7 @@ namespace Tensorflow.Keras.Layers | |||
| UseBias = use_bias, | |||
| KernelInitializer = GetInitializerByName(kernel_initializer), | |||
| RecurrentInitializer = GetInitializerByName(recurrent_initializer), | |||
| BiasInitializer = GetInitializerByName(bias_initializer), | |||
| Dropout = dropout, | |||
| RecurrentDropout = recurrent_dropout | |||
| }); | |||
| @@ -786,6 +787,33 @@ namespace Tensorflow.Keras.Layers | |||
| TimeMajor = time_major | |||
| }); | |||
public IRnnCell LSTMCell(int units,
string activation = "tanh",
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal", // TODO(Wanglongzhi2001): glorot_uniform has not been implemented yet.
string bias_initializer = "zeros",
bool unit_forget_bias = true,
float dropout = 0f,
float recurrent_dropout = 0f,
int implementation = 2)
=> new LSTMCell(new LSTMCellArgs
{
Units = units,
Activation = keras.activations.GetActivationFromName(activation),
RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation),
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
UnitForgetBias = unit_forget_bias,
Dropout = dropout,
RecurrentDropout = recurrent_dropout,
Implementation = implementation
});
/// <summary>
/// Long Short-Term Memory layer - Hochreiter 1997.
/// </summary>
| @@ -41,7 +41,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| } | |||
| public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) | |||
| public Tensors? get_dropout_mask_for_cell(Tensors input, bool training, int count = 1) | |||
| { | |||
| if (dropout == 0f) | |||
| return null; | |||
| @@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| } | |||
| // Get the recurrent dropout mask for RNN cell. | |||
| public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) | |||
| public Tensors? get_recurrent_dropout_mask_for_cell(Tensors input, bool training, int count = 1) | |||
| { | |||
| if (dropout == 0f) | |||
| return null; | |||
| @@ -1,16 +1,240 @@ | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Serilog.Core; | |||
| using System.Diagnostics; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Keras.Utils; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| public class LSTMCell : Layer | |||
| /// <summary> | |||
| /// Cell class for the LSTM layer. | |||
| /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) | |||
| /// for details about the usage of RNN API. | |||
| /// This class processes one step within the whole time sequence input, whereas | |||
| /// `tf.keras.layer.LSTM` processes the whole sequence. | |||
| /// </summary> | |||
| public class LSTMCell : DropoutRNNCellMixin | |||
| { | |||
| LSTMCellArgs args; | |||
| LSTMCellArgs _args; | |||
IVariableV1 _kernel;
IVariableV1 _recurrent_kernel;
IInitializer _bias_initializer;
IVariableV1 _bias;
GeneralizedTensorShape _state_size;
GeneralizedTensorShape _output_size;
public override GeneralizedTensorShape StateSize => _state_size;
public override GeneralizedTensorShape OutputSize => _output_size;
public override bool IsTFRnnCell => true;
public override bool SupportOptionalArgs => false;
public LSTMCell(LSTMCellArgs args)
: base(args)
{
this.args = args;
_args = args;
if (args.Units <= 0)
{
throw new ValueError(
$"units must be a positive integer, got {args.Units}");
}
_args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout));
_args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout));
if (_args.RecurrentDropout != 0f && _args.Implementation != 1)
{
Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set. " +
"Using `implementation=1`.");
_args.Implementation = 1;
}
_state_size = new GeneralizedTensorShape(_args.Units, 2);
_output_size = new GeneralizedTensorShape(_args.Units);
}
public override void build(KerasShapesWrapper input_shape)
{
var single_shape = input_shape.ToSingleShape();
var input_dim = single_shape[-1];
_kernel = add_weight("kernel", (input_dim, _args.Units * 4),
initializer: _args.KernelInitializer
);
_recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units * 4),
initializer: _args.RecurrentInitializer
);
if (_args.UseBias)
{
if (_args.UnitForgetBias)
{
// TODO: this forget-gate-bias-of-ones initializer is not yet wired into
// the add_weight call below, which still uses the plain BiasInitializer.
Tensor bias_initializer()
{
return keras.backend.concatenate(
new Tensors(
_args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units))),
tf.ones_initializer.Apply(new InitializerArgs(shape: (_args.Units))),
_args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units * 2)))), axis: 0);
}
}
else
{
_bias_initializer = _args.BiasInitializer;
}
_bias = add_weight("bias", (_args.Units * 4),
initializer: _args.BiasInitializer);
}
built = true;
}
protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
{
var h_tm1 = states[0]; // previous memory state
var c_tm1 = states[1]; // previous carry state
var dp_mask = get_dropout_mask_for_cell(inputs, training.Value, count: 4);
var rec_dp_mask = get_recurrent_dropout_mask_for_cell(
h_tm1, training.Value, count: 4);
Tensor c;
Tensor o;
if (_args.Implementation == 1)
{
Tensor inputs_i;
Tensor inputs_f;
Tensor inputs_c;
Tensor inputs_o;
if (0f < _args.Dropout && _args.Dropout < 1f)
{
inputs_i = inputs * dp_mask[0];
inputs_f = inputs * dp_mask[1];
inputs_c = inputs * dp_mask[2];
inputs_o = inputs * dp_mask[3];
}
else
{
inputs_i = inputs;
inputs_f = inputs;
inputs_c = inputs;
inputs_o = inputs;
}
var k = tf.split(_kernel.AsTensor(), num_split: 4, axis: 1);
Tensor k_i = k[0], k_f = k[1], k_c = k[2], k_o = k[3];
var x_i = math_ops.matmul(inputs_i, k_i);
var x_f = math_ops.matmul(inputs_f, k_f);
var x_c = math_ops.matmul(inputs_c, k_c);
var x_o = math_ops.matmul(inputs_o, k_o);
if (_args.UseBias)
{
var b = tf.split(_bias.AsTensor(), num_split: 4, axis: 0);
Tensor b_i = b[0], b_f = b[1], b_c = b[2], b_o = b[3];
x_i = gen_nn_ops.bias_add(x_i, b_i);
x_f = gen_nn_ops.bias_add(x_f, b_f);
x_c = gen_nn_ops.bias_add(x_c, b_c);
x_o = gen_nn_ops.bias_add(x_o, b_o);
}
Tensor h_tm1_i;
Tensor h_tm1_f;
Tensor h_tm1_c;
Tensor h_tm1_o;
if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f)
{
h_tm1_i = h_tm1 * rec_dp_mask[0];
h_tm1_f = h_tm1 * rec_dp_mask[1];
h_tm1_c = h_tm1 * rec_dp_mask[2];
h_tm1_o = h_tm1 * rec_dp_mask[3];
}
else
{
h_tm1_i = h_tm1;
h_tm1_f = h_tm1;
h_tm1_c = h_tm1;
h_tm1_o = h_tm1;
}
var x = new Tensor[] { x_i, x_f, x_c, x_o };
var h_tm1_array = new Tensor[] { h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o };
(c, o) = _compute_carry_and_output(x, h_tm1_array, c_tm1);
}
else
{
if (0f < _args.Dropout && _args.Dropout < 1f)
inputs = inputs * dp_mask[0];
var z = math_ops.matmul(inputs, _kernel.AsTensor());
z += math_ops.matmul(h_tm1, _recurrent_kernel.AsTensor());
if (_args.UseBias)
{
z = tf.nn.bias_add(z, _bias);
}
var z_array = tf.split(z, num_split: 4, axis: 1);
(c, o) = _compute_carry_and_output_fused(z_array, c_tm1);
}
var h = o * _args.Activation.Apply(c);
// This is because the Tensors constructor packs every element after the first
// into an array, so the cell's (h, [h, c]) output is written as Tensors(h, h, c).
return new Tensors(h, h, c);
}
/// <summary>
/// Computes carry and output using split kernels.
/// </summary>
/// <param name="x"></param>
/// <param name="h_tm1"></param>
/// <param name="c_tm1"></param>
/// <returns></returns>
public Tensors _compute_carry_and_output(Tensor[] x, Tensor[] h_tm1, Tensor c_tm1)
{
Tensor x_i = x[0], x_f = x[1], x_c = x[2], x_o = x[3];
Tensor h_tm1_i = h_tm1[0], h_tm1_f = h_tm1[1], h_tm1_c = h_tm1[2],
h_tm1_o = h_tm1[3];
var _recurrent_kernel_tensor = _recurrent_kernel.AsTensor();
var num_rows = _recurrent_kernel_tensor.shape[0];
// tf.slice takes (begin, size): each gate reads a (num_rows, units) column block.
var _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, 0 }, new[] { num_rows, _args.Units });
var i = _args.RecurrentActivation.Apply(
x_i + math_ops.matmul(h_tm1_i, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units }, new[] { num_rows, _args.Units });
var f = _args.RecurrentActivation.Apply(
x_f + math_ops.matmul(h_tm1_f, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units * 2 }, new[] { num_rows, _args.Units });
var c = f * c_tm1 + i * _args.Activation.Apply(
x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units * 3 }, new[] { num_rows, _args.Units });
var o = _args.RecurrentActivation.Apply(
x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice));
return new Tensors(c, o);
}
/// <summary>
/// Computes carry and output using fused kernels.
/// </summary>
/// <param name="z"></param>
/// <param name="c_tm1"></param>
/// <returns></returns>
public Tensors _compute_carry_and_output_fused(Tensor[] z, Tensor c_tm1)
{
Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3];
var i = _args.RecurrentActivation.Apply(z0);
var f = _args.RecurrentActivation.Apply(z1);
// The candidate cell state uses the (non-recurrent) activation, matching the unfused path.
var c = f * c_tm1 + i * _args.Activation.Apply(z2);
var o = _args.RecurrentActivation.Apply(z3);
return new Tensors(c, o);
}
public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
{
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
}
}
}
| @@ -74,8 +74,8 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| // TODO(Rinne): check if it will have multiple tensors when not nested. | |||
| Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states; | |||
| var dp_mask = get_dropout_maskcell_for_cell(inputs, training.Value); | |||
| var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value); | |||
| var dp_mask = get_dropout_mask_for_cell(inputs, training.Value); | |||
| var rec_dp_mask = get_recurrent_dropout_mask_for_cell(prev_output, training.Value); | |||
| Tensor h; | |||
| var ranks = inputs.rank; | |||
| @@ -21,21 +21,6 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||
| [TestMethod] | |||
| public void SimpleRNNCell() | |||
| { | |||
| //var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f); | |||
| //var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; | |||
| //var x = tf.random.normal((4, 100)); | |||
| //var (y, h1) = cell.Apply(inputs: x, states: h0); | |||
| //var h2 = h1; | |||
| //Assert.AreEqual((4, 64), y.shape); | |||
| //Assert.AreEqual((4, 64), h2[0].shape); | |||
| //var model = keras.Sequential(new List<ILayer> | |||
| //{ | |||
| // keras.layers.InputLayer(input_shape: (4,100)), | |||
| // keras.layers.SimpleRNNCell(64) | |||
| //}); | |||
| //model.summary(); | |||
| var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f); | |||
| var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; | |||
| var x = tf.random.normal((4, 100)); | |||
| @@ -59,6 +44,17 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||
| Assert.AreEqual((32, 4), state[0].shape); | |||
| } | |||
| [TestMethod] | |||
| public void LSTMCell() | |||
| { | |||
| var inputs = tf.ones((2, 100)); | |||
| var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) }; | |||
| var rnn = tf.keras.layers.LSTMCell(4); | |||
| var (output, new_states) = rnn.Apply(inputs, states); | |||
| Assert.AreEqual((2, 4), output.shape); | |||
| Assert.AreEqual((2, 4), new_states[0].shape); | |||
| } | |||
| [TestMethod] | |||
| public void SimpleRNN() | |||
| { | |||
| @@ -105,15 +101,27 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||
| } | |||
| [TestMethod] | |||
| public void WlzTest() | |||
| public void RNNForLSTMCell() | |||
| { | |||
| long[] b = { 1, 2, 3 }; | |||
| Shape a = new Shape(Unknown).concatenate(b); | |||
| Console.WriteLine(a); | |||
| var inputs = tf.ones((5, 10, 8)); | |||
| var rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4)); | |||
| var output = rnn.Apply(inputs); | |||
| Console.WriteLine($"output: {output}"); | |||
| Assert.AreEqual((5, 4), output.shape); | |||
| } | |||
| [TestMethod] | |||
| public void MyTest() | |||
| { | |||
| var a = tf.zeros((2, 3)); | |||
| var b = tf.ones_like(a); | |||
| var c = tf.ones((3,4)); | |||
| var d = new Tensors { a, b, c }; | |||
| var (A, BC) = d; | |||
| Console.WriteLine($"A:{A}"); | |||
| Console.WriteLine($"BC:{BC}"); | |||
| } | |||
| } | |||
| } | |||