feat: implement GRU layer
@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
    public class GRUArgs : AutoSerializeLayerArgs
    {
        public int Units { get; set; }
        public Activation Activation { get; set; }
        public Activation RecurrentActivation { get; set; }
        public bool UseBias { get; set; } = true;
        public float Dropout { get; set; } = .0f;
        public float RecurrentDropout { get; set; } = .0f;
        public IInitializer KernelInitializer { get; set; }
        public IInitializer RecurrentInitializer { get; set; }
        public IInitializer BiasInitializer { get; set; }
        public bool ReturnSequences { get; set; }
        public bool ReturnState { get; set; }
        public bool GoBackwards { get; set; }
        public bool Stateful { get; set; }
        public bool Unroll { get; set; }
        public bool TimeMajor { get; set; }
        public bool ResetAfter { get; set; }
        public int Implementation { get; set; } = 2;
    }
}
@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
    // Implements IOptionalArgs so instances can be passed through the
    // `optional_args` parameter and recognized by GRU.Call below; without
    // the interface the `as GRUOptionalArgs` cast in Call could never succeed.
    public class GRUOptionalArgs : IOptionalArgs
    {
        public string Identifier => "GRU";

        public Tensor Mask { get; set; } = null;
    }
}
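For orientation, here is a hypothetical usage sketch (not part of the diff) of how a caller might thread a mask through these optional args. It assumes `Apply` forwards its `optional_args` parameter down to `Call`, as the `GRU.Call` override later in this PR suggests, and uses the `tf.keras.layers.GRU` factory added below.

    // Hypothetical usage, assuming Apply forwards optional_args to Call.
    var gru = tf.keras.layers.GRU(4);
    var inputs = tf.ones((32, 10, 8));   // (batch, timesteps, features)
    var mask = tf.ones((32, 10));        // all timesteps valid; a real mask would zero out padded steps
    var output = gru.Apply(inputs, optional_args: new GRUOptionalArgs { Mask = mask });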
@@ -259,6 +259,25 @@ namespace Tensorflow.Keras.Layers
            float recurrent_dropout = 0f,
            bool reset_after = true);

+       public ILayer GRU(
+           int units,
+           string activation = "tanh",
+           string recurrent_activation = "sigmoid",
+           bool use_bias = true,
+           string kernel_initializer = "glorot_uniform",
+           string recurrent_initializer = "orthogonal",
+           string bias_initializer = "zeros",
+           float dropout = 0f,
+           float recurrent_dropout = 0f,
+           bool return_sequences = false,
+           bool return_state = false,
+           bool go_backwards = false,
+           bool stateful = false,
+           bool unroll = false,
+           bool time_major = false,
+           bool reset_after = true
+       );

        /// <summary>
        /// Bidirectional wrapper for RNNs.
        /// </summary>
@@ -784,7 +784,7 @@ namespace Tensorflow.Keras.Layers
            string recurrent_activation = "sigmoid",
            bool use_bias = true,
            string kernel_initializer = "glorot_uniform",
-           string recurrent_initializer = "orthogonal", // TODO(Wanglongzhi2001), glorot_uniform has not been developed.
+           string recurrent_initializer = "orthogonal",
            string bias_initializer = "zeros",
            bool unit_forget_bias = true,
            float dropout = 0f,
@@ -908,6 +908,65 @@ namespace Tensorflow.Keras.Layers
                ResetAfter = reset_after
            });

+       /// <summary>
+       /// Gated Recurrent Unit - Cho et al. 2014.
+       /// </summary>
+       /// <param name="units">Positive integer, dimensionality of the output space.</param>
+       /// <param name="activation">Activation function to use. If you pass `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).</param>
+       /// <param name="recurrent_activation">Activation function to use for the recurrent step. If you pass `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).</param>
+       /// <param name="use_bias">Boolean (default `True`), whether the layer uses a bias vector.</param>
+       /// <param name="kernel_initializer">Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. Default: `glorot_uniform`.</param>
+       /// <param name="recurrent_initializer">Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. Default: `orthogonal`.</param>
+       /// <param name="bias_initializer">Initializer for the bias vector. Default: `zeros`.</param>
+       /// <param name="dropout">Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. Default: 0.</param>
+       /// <param name="recurrent_dropout">Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. Default: 0.</param>
+       /// <param name="return_sequences">Boolean. Whether to return the last output in the output sequence, or the full sequence. Default: `False`.</param>
+       /// <param name="return_state">Boolean. Whether to return the last state in addition to the output. Default: `False`.</param>
+       /// <param name="go_backwards">Boolean (default `False`). If `True`, process the input sequence backwards and return the reversed sequence.</param>
+       /// <param name="stateful">Boolean (default `False`). If `True`, the last state for each sample at index i in a batch will be used as the initial state for the sample of index i in the following batch.</param>
+       /// <param name="unroll">Boolean (default `False`). If `True`, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up an RNN, although it tends to be more memory-intensive and is only suitable for short sequences.</param>
+       /// <param name="time_major">The shape format of the `inputs` and `outputs` tensors. If `True`, inputs and outputs are shaped `(timesteps, batch, ...)`; otherwise `(batch, timesteps, ...)`.</param>
+       /// <param name="reset_after">GRU convention (whether to apply the reset gate after or before the matrix multiplication). `False` = "before", `True` = "after" (default and cuDNN compatible).</param>
+       /// <returns>An `ILayer` wrapping the GRU layer.</returns>
+       public ILayer GRU(
+           int units,
+           string activation = "tanh",
+           string recurrent_activation = "sigmoid",
+           bool use_bias = true,
+           string kernel_initializer = "glorot_uniform",
+           string recurrent_initializer = "orthogonal",
+           string bias_initializer = "zeros",
+           float dropout = 0f,
+           float recurrent_dropout = 0f,
+           bool return_sequences = false,
+           bool return_state = false,
+           bool go_backwards = false,
+           bool stateful = false,
+           bool unroll = false,
+           bool time_major = false,
+           bool reset_after = true
+       )
+           => new GRU(new GRUArgs
+           {
+               Units = units,
+               Activation = keras.activations.GetActivationFromName(activation),
+               RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation),
+               KernelInitializer = GetInitializerByName(kernel_initializer),
+               RecurrentInitializer = GetInitializerByName(recurrent_initializer),
+               BiasInitializer = GetInitializerByName(bias_initializer),
+               UseBias = use_bias,
+               Dropout = dropout,
+               RecurrentDropout = recurrent_dropout,
+               ReturnSequences = return_sequences,
+               ReturnState = return_state,
+               GoBackwards = go_backwards,
+               Stateful = stateful,
+               TimeMajor = time_major,
+               Unroll = unroll,
+               ResetAfter = reset_after
+           });

        public ILayer Bidirectional(
            ILayer layer,
            string merge_mode = "concat",
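A minimal usage sketch of the new factory (the default output shape is the case verified by the unit test added at the end of this PR; parameter names follow the signature above):

    var inputs = tf.ones((32, 10, 8));   // (batch, timesteps, features)
    var gru = tf.keras.layers.GRU(4, dropout: 0.2f, go_backwards: true);
    var output = gru.Apply(inputs);      // (32, 4): last output per sample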
@@ -0,0 +1,168 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers
{
    public class GRU : RNN
    {
        GRUArgs _args;
        private GRUCell _cell; // instance field: a static cell would be shared across all GRU layers
        bool _return_runtime;

        public GRUCell Cell { get => _cell; }
        public int units { get => _args.Units; }
        public Activation activation { get => _args.Activation; }
        public Activation recurrent_activation { get => _args.RecurrentActivation; }
        public bool use_bias { get => _args.UseBias; }
        public float dropout { get => _args.Dropout; }
        public float recurrent_dropout { get => _args.RecurrentDropout; }
        public IInitializer kernel_initializer { get => _args.KernelInitializer; }
        public IInitializer recurrent_initializer { get => _args.RecurrentInitializer; }
        public IInitializer bias_initializer { get => _args.BiasInitializer; }
        public int implementation { get => _args.Implementation; }
        public bool reset_after { get => _args.ResetAfter; }

        public GRU(GRUArgs args) : base(CreateCell(args), PreConstruct(args))
        {
            _args = args;

            if (_args.Implementation == 0)
            {
                // Print the warning in red so it stays visible in release builds.
                Console.ForegroundColor = ConsoleColor.Red;
                Console.WriteLine("Warning: `implementation=0` has been deprecated, " +
                    "and now defaults to `implementation=2`. " +
                    "Please update your layer call.");
                Console.ResetColor();
            }

            GRUCell cell = new GRUCell(new GRUCellArgs
            {
                Units = _args.Units,
                Activation = _args.Activation,
                RecurrentActivation = _args.RecurrentActivation,
                UseBias = _args.UseBias,
                Dropout = _args.Dropout,
                RecurrentDropout = _args.RecurrentDropout,
                KernelInitializer = _args.KernelInitializer,
                RecurrentInitializer = _args.RecurrentInitializer,
                BiasInitializer = _args.BiasInitializer,
                ResetAfter = _args.ResetAfter,
                Implementation = _args.Implementation
            });
            _cell = cell;
        }

        protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            GRUOptionalArgs? gru_optional_args = optional_args as GRUOptionalArgs;
            if (optional_args is not null && gru_optional_args is null)
            {
                throw new ArgumentException("The type of optional args should be `GRUOptionalArgs`.");
            }
            Tensors? mask = gru_optional_args?.Mask;

            // Ragged inputs are not supported yet.
            int row_length = 0;
            bool is_ragged_input = false;
            _validate_args_if_ragged(is_ragged_input, mask);

            // GRU does not support constants; ignore them during processing.
            (inputs, initial_state, _) = this._process_inputs(inputs, initial_state, null);

            if (mask is not null && mask.Length > 1)
            {
                mask = mask[0];
            }

            var input_shape = inputs.shape;
            var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];

            // TODO(Wanglongzhi2001), finish the _could_use_gpu_kernel part.

            Func<Tensors, Tensors, (Tensors, Tensors)> step = (cell_inputs, cell_states) =>
            {
                var res = Cell.Apply(cell_inputs, cell_states, training is null ? true : training.Value);
                var (output, state) = res;
                return (output, state);
            };

            var (last_output, outputs, states) = keras.backend.rnn(
                step,
                inputs,
                initial_state,
                constants: null,
                go_backwards: _args.GoBackwards,
                mask: mask,
                unroll: _args.Unroll,
                input_length: ops.convert_to_tensor(timesteps),
                time_major: _args.TimeMajor,
                zero_output_for_mask: base.Args.ZeroOutputForMask,
                return_all_outputs: _args.ReturnSequences
            );

            Tensors output;
            if (_args.ReturnSequences)
            {
                output = outputs;
            }
            else
            {
                output = last_output;
            }

            if (_args.ReturnState)
            {
                output = new Tensors { output, states };
            }
            return output;
        }

        private static IRnnCell CreateCell(GRUArgs gruArgs)
        {
            return new GRUCell(new GRUCellArgs
            {
                Units = gruArgs.Units,
                Activation = gruArgs.Activation,
                RecurrentActivation = gruArgs.RecurrentActivation,
                UseBias = gruArgs.UseBias,
                Dropout = gruArgs.Dropout,
                RecurrentDropout = gruArgs.RecurrentDropout,
                KernelInitializer = gruArgs.KernelInitializer,
                RecurrentInitializer = gruArgs.RecurrentInitializer,
                BiasInitializer = gruArgs.BiasInitializer,
                ResetAfter = gruArgs.ResetAfter,
                Implementation = gruArgs.Implementation
            });
        }

        private static RNNArgs PreConstruct(GRUArgs args)
        {
            return new RNNArgs
            {
                ReturnSequences = args.ReturnSequences,
                ReturnState = args.ReturnState,
                GoBackwards = args.GoBackwards,
                Stateful = args.Stateful,
                Unroll = args.Unroll,
                TimeMajor = args.TimeMajor,
                Units = args.Units,
                Activation = args.Activation,
                RecurrentActivation = args.RecurrentActivation,
                UseBias = args.UseBias,
                Dropout = args.Dropout,
                RecurrentDropout = args.RecurrentDropout,
                KernelInitializer = args.KernelInitializer,
                RecurrentInitializer = args.RecurrentInitializer,
                BiasInitializer = args.BiasInitializer
            };
        }
    }
}
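A short sketch (not part of the diff) of consuming `return_state`, based on how `Call` above packs the result as `new Tensors { output, states }`; the element indexing on `Tensors` mirrors its use in the implementation (`mask[0]`, `inputs[0]`):

    var gru = tf.keras.layers.GRU(4, return_state: true);
    var result = gru.Apply(tf.ones((32, 10, 8)));
    var output = result[0];   // (32, 4): last output
    var state = result[1];    // (32, 4): final hidden state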
@@ -25,8 +25,8 @@ namespace Tensorflow.Keras.Layers
        private RNNArgs _args;
        private object _input_spec = null; // or NoneValue??
        private object _state_spec = null;
-       private Tensors _states = null;
        private object _constants_spec = null;
+       private Tensors _states = null;
        private int _num_constants;
        protected IVariableV1 _kernel;
        protected IVariableV1 _bias;
@@ -469,7 +469,7 @@ namespace Tensorflow.Keras.Layers
            return (inputs, initial_state, constants);
        }

-       private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
+       protected void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
        {
            if (!is_ragged_input)
            {
@@ -528,44 +528,6 @@ namespace Tensorflow.Keras.Layers
            throw new NotImplementedException();
        }

-       // It seems the cell cannot be passed as an interface type.
-       //public RNN New(IRnnArgCell cell,
-       //    bool return_sequences = false,
-       //    bool return_state = false,
-       //    bool go_backwards = false,
-       //    bool stateful = false,
-       //    bool unroll = false,
-       //    bool time_major = false)
-       //    => new RNN(new RNNArgs
-       //    {
-       //        Cell = cell,
-       //        ReturnSequences = return_sequences,
-       //        ReturnState = return_state,
-       //        GoBackwards = go_backwards,
-       //        Stateful = stateful,
-       //        Unroll = unroll,
-       //        TimeMajor = time_major
-       //    });
-
-       //public RNN New(List<IRnnArgCell> cell,
-       //    bool return_sequences = false,
-       //    bool return_state = false,
-       //    bool go_backwards = false,
-       //    bool stateful = false,
-       //    bool unroll = false,
-       //    bool time_major = false)
-       //    => new RNN(new RNNArgs
-       //    {
-       //        Cell = cell,
-       //        ReturnSequences = return_sequences,
-       //        ReturnState = return_state,
-       //        GoBackwards = go_backwards,
-       //        Stateful = stateful,
-       //        Unroll = unroll,
-       //        TimeMajor = time_major
-       //    });

        protected Tensors get_initial_state(Tensors inputs)
        {
            var input = inputs[0];
@@ -146,6 +146,15 @@ namespace Tensorflow.Keras.UnitTest.Layers
        }

+       [TestMethod]
+       public void GRU()
+       {
+           var inputs = tf.ones((32, 10, 8));
+           var gru = tf.keras.layers.GRU(4);
+           var output = gru.Apply(inputs);
+           Assert.AreEqual((32, 4), output.shape);
+       }

        [TestMethod]
        public void Bidirectional()
        {
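A possible companion test (not included in this PR) for the `return_sequences` path, in the same style as the test above; the expected shape follows standard Keras GRU semantics:

    [TestMethod]
    public void GRUReturnSequences()
    {
        var inputs = tf.ones((32, 10, 8));
        var gru = tf.keras.layers.GRU(4, return_sequences: true);
        var output = gru.Apply(inputs);
        Assert.AreEqual((32, 10, 4), output.shape);
    }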