using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Saving;
namespace Tensorflow.Keras.Layers.Rnn
{
    /// <summary>
    /// Cell class for the GRU layer: computes a single Gated Recurrent Unit time step.
    /// </summary>
    /// <remarks>
    /// Weights hold the update (z), reset (r) and candidate (h) gates concatenated along the
    /// last axis in that order: <c>kernel</c> is (input_dim, 3 * units) and
    /// <c>recurrent_kernel</c> is (units, 3 * units). When <c>ResetAfter</c> is set the bias has
    /// shape (2, 3 * units), row 0 for the input projection and row 1 for the recurrent one;
    /// otherwise it is a single (3 * units) vector applied to the input projection only.
    /// </remarks>
    public class GRUCell : DropoutRNNCellMixin
    {
        GRUCellArgs _args;
        IVariableV1 _kernel;
        IVariableV1 _recurrent_kernel;
        IInitializer _bias_initializer;
        IVariableV1 _bias;
        INestStructure _state_size;
        INestStructure _output_size;
        int Units;

        public override INestStructure StateSize => _state_size;
        public override INestStructure OutputSize => _output_size;
        public override bool SupportOptionalArgs => false;

        /// <summary>
        /// Creates the cell, validating the unit count and normalizing the dropout settings.
        /// </summary>
        /// <param name="args">
        /// Cell configuration. <c>Dropout</c> and <c>RecurrentDropout</c> are clamped to [0, 1] in place,
        /// and <c>Implementation</c> is forced to 1 when recurrent dropout is non-zero.
        /// </param>
        /// <exception cref="ValueError">Thrown when <c>args.Units</c> is not positive.</exception>
        public GRUCell(GRUCellArgs args) : base(args)
        {
            _args = args;
            if (_args.Units <= 0)
            {
                throw new ValueError(
                    $"units must be a positive integer, got {args.Units}");
            }
            _args.Dropout = Math.Min(1f, Math.Max(0f, _args.Dropout));
            _args.RecurrentDropout = Math.Min(1f, Math.Max(0f, _args.RecurrentDropout));
            if (_args.RecurrentDropout != 0f && _args.Implementation != 1)
            {
                Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set. " +
                    "Using `implementation=1`.");
                _args.Implementation = 1;
            }
            // Previously this field was never assigned, so build() passed a null initializer and the
            // configured bias initializer was ignored. Zeros is the fallback when none is configured.
            _bias_initializer = _args.BiasInitializer ?? tf.zeros_initializer;
            Units = _args.Units;
            _state_size = new NestList(Units);
            _output_size = new NestNode(Units);
        }

        /// <summary>
        /// Creates the kernel, the recurrent kernel and (when <c>UseBias</c> is set) the bias weights.
        /// </summary>
        /// <param name="input_shape">Shape of the cell input; its last dimension is the input feature size.</param>
        public override void build(KerasShapesWrapper input_shape)
        {
            var single_shape = input_shape.ToSingleShape();
            var input_dim = single_shape[-1];
            _kernel = add_weight("kernel", (input_dim, Units * 3),
                initializer: _args.KernelInitializer
            );
            _recurrent_kernel = add_weight("recurrent_kernel", (Units, Units * 3),
                initializer: _args.RecurrentInitializer
            );
            if (_args.UseBias)
            {
                Shape bias_shape;
                if (_args.ResetAfter)
                {
                    // Separate rows for the input-projection and recurrent-projection biases.
                    bias_shape = (2, 3 * Units);
                }
                else
                {
                    bias_shape = new Shape(3 * Units);
                }
                _bias = add_weight("bias", bias_shape,
                    initializer: _bias_initializer
                );
            }
            built = true;
        }

        /// <summary>
        /// Runs one GRU step: h = z * h_tm1 + (1 - z) * hh.
        /// </summary>
        /// <param name="inputs">Input for this time step.</param>
        /// <param name="states">Previous state: a single tensor, or a nested structure whose first entry is h_tm1.</param>
        /// <param name="training">Training flag forwarded to the dropout-mask helpers; null is treated as false.</param>
        /// <param name="optional_args">Not used by this cell.</param>
        /// <returns>
        /// The new hidden state h followed by the new state, which is wrapped in a list when
        /// <paramref name="states"/> is nested and is h itself otherwise.
        /// </returns>
        protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            var h_tm1 = states.IsNested() ? states[0] : states.Single();
            // training.Value would throw for the (default) null case.
            bool is_training = training ?? false;
            var dp_mask = get_dropout_mask_for_cell(inputs, is_training, count: 3);
            var rec_dp_mask = get_recurrent_dropout_mask_for_cell(h_tm1, is_training, count: 3);

            // Biases are used as plain tensors. Wrapping their slices in tf.Variable created fresh
            // variables on every step and detached the result from _bias for gradient purposes.
            Tensor input_bias = null;
            Tensor recurrent_bias = null;
            if (_args.UseBias)
            {
                if (_args.ResetAfter)
                {
                    var biases = tf.unstack(_bias.AsTensor());
                    input_bias = biases[0];
                    recurrent_bias = biases[1];
                }
                else
                {
                    input_bias = _bias.AsTensor();
                }
            }

            var (z, hh) = _args.Implementation == 1
                ? StepImplementation1(inputs, h_tm1, dp_mask, rec_dp_mask, input_bias, recurrent_bias)
                : StepImplementation2(inputs, h_tm1, dp_mask, input_bias, recurrent_bias);

            var h = z * h_tm1 + (1 - z) * hh;
            if (states.IsNested())
            {
                var new_state = new NestList(h);
                return new Nest(new INestStructure[] { new NestNode(h), new_state }).ToTensors();
            }
            else
            {
                return new Nest(new INestStructure[] { new NestNode(h), new NestNode(h) }).ToTensors();
            }
        }

        /// <summary>
        /// Implementation 1: one matmul per gate, with a separate dropout mask per gate.
        /// Returns the update gate z and the candidate state hh.
        /// </summary>
        private (Tensor z, Tensor hh) StepImplementation1(Tensors inputs, Tensor h_tm1,
            Tensors dp_mask, Tensors rec_dp_mask, Tensor input_bias, Tensor recurrent_bias)
        {
            Tensor x = inputs.Single();
            Tensor inputs_z;
            Tensor inputs_r;
            Tensor inputs_h;
            if (0f < _args.Dropout && _args.Dropout < 1f)
            {
                inputs_z = x * dp_mask[0];
                inputs_r = x * dp_mask[1];
                inputs_h = x * dp_mask[2];
            }
            else
            {
                inputs_z = x;
                inputs_r = x;
                inputs_h = x;
            }

            var kernel = _kernel.AsTensor();
            var x_z = math_ops.matmul(inputs_z, GateColumns(kernel, 0));
            var x_r = math_ops.matmul(inputs_r, GateColumns(kernel, 1));
            var x_h = math_ops.matmul(inputs_h, GateColumns(kernel, 2));
            if (_args.UseBias)
            {
                x_z = x_z + GateBias(input_bias, 0);
                x_r = x_r + GateBias(input_bias, 1);
                x_h = x_h + GateBias(input_bias, 2);
            }

            Tensor h_tm1_z;
            Tensor h_tm1_r;
            Tensor h_tm1_h;
            if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f)
            {
                h_tm1_z = h_tm1 * rec_dp_mask[0];
                h_tm1_r = h_tm1 * rec_dp_mask[1];
                h_tm1_h = h_tm1 * rec_dp_mask[2];
            }
            else
            {
                h_tm1_z = h_tm1;
                h_tm1_r = h_tm1;
                h_tm1_h = h_tm1;
            }

            var recurrent_kernel = _recurrent_kernel.AsTensor();
            var recurrent_z = math_ops.matmul(h_tm1_z, GateColumns(recurrent_kernel, 0));
            var recurrent_r = math_ops.matmul(h_tm1_r, GateColumns(recurrent_kernel, 1));
            if (_args.ResetAfter && _args.UseBias)
            {
                recurrent_z = recurrent_z + GateBias(recurrent_bias, 0);
                recurrent_r = recurrent_r + GateBias(recurrent_bias, 1);
            }

            Tensor z = _args.RecurrentActivation.Apply(x_z + recurrent_z);
            Tensor r = _args.RecurrentActivation.Apply(x_r + recurrent_r);

            Tensor recurrent_h;
            if (_args.ResetAfter)
            {
                // Reset gate is applied after the recurrent projection (and its bias).
                recurrent_h = math_ops.matmul(h_tm1_h, GateColumns(recurrent_kernel, 2));
                if (_args.UseBias)
                {
                    recurrent_h = recurrent_h + GateBias(recurrent_bias, 2);
                }
                recurrent_h = r * recurrent_h;
            }
            else
            {
                // Reset gate is applied to the previous state before the projection.
                recurrent_h = math_ops.matmul(r * h_tm1_h, GateColumns(recurrent_kernel, 2));
            }

            Tensor hh = _args.Activation.Apply(x_h + recurrent_h);
            return (z, hh);
        }

        /// <summary>
        /// Implementation 2: projects the input (and, with ResetAfter, the state) for all gates in a
        /// single matmul and splits the result. Returns the update gate z and the candidate state hh.
        /// </summary>
        private (Tensor z, Tensor hh) StepImplementation2(Tensors inputs, Tensor h_tm1,
            Tensors dp_mask, Tensor input_bias, Tensor recurrent_bias)
        {
            Tensor x = inputs.Single();
            if (0f < _args.Dropout && _args.Dropout < 1f)
            {
                x = x * dp_mask[0];
            }

            var matrix_x = math_ops.matmul(x, _kernel.AsTensor());
            if (_args.UseBias)
            {
                matrix_x = matrix_x + input_bias;
            }
            var x_parts = tf.split(matrix_x, 3, axis: -1);
            var x_z = x_parts[0];
            var x_r = x_parts[1];
            var x_h = x_parts[2];

            var recurrent_kernel = _recurrent_kernel.AsTensor();
            Tensor matrix_inner;
            if (_args.ResetAfter)
            {
                matrix_inner = math_ops.matmul(h_tm1, recurrent_kernel);
                if (_args.UseBias)
                {
                    matrix_inner = matrix_inner + recurrent_bias;
                }
            }
            else
            {
                // Only the z and r projections are computed here; the h projection needs r first.
                matrix_inner = math_ops.matmul(h_tm1, GateColumns(recurrent_kernel, 0, gate_count: 2));
            }

            // The -1 split takes whatever remains (empty when ResetAfter is false).
            var inner_parts = tf.split(matrix_inner, new int[] { Units, Units, -1 }, axis: -1);
            var recurrent_z = inner_parts[0];
            var recurrent_r = inner_parts[1];
            var recurrent_h = inner_parts[2];

            Tensor z = _args.RecurrentActivation.Apply(x_z + recurrent_z);
            Tensor r = _args.RecurrentActivation.Apply(x_r + recurrent_r);

            if (_args.ResetAfter)
            {
                recurrent_h = r * recurrent_h;
            }
            else
            {
                recurrent_h = math_ops.matmul(r * h_tm1, GateColumns(recurrent_kernel, 2));
            }

            Tensor hh = _args.Activation.Apply(x_h + recurrent_h);
            return (z, hh);
        }

        /// <summary>
        /// Returns the columns of a rank-2 weight matrix belonging to <paramref name="gate_count"/>
        /// consecutive gates starting at <paramref name="first_gate"/>,
        /// i.e. <c>weights[:, first_gate * Units : (first_gate + gate_count) * Units]</c>.
        /// </summary>
        private Tensor GateColumns(Tensor weights, int first_gate, int gate_count = 1)
        {
            var rows = (int)weights.shape[0];
            return tf.slice(weights,
                new[] { 0, first_gate * Units }, new[] { rows, gate_count * Units });
        }

        /// <summary>
        /// Returns the slice of a 1-D bias vector belonging to <paramref name="gate"/>,
        /// i.e. <c>bias[gate * Units : (gate + 1) * Units]</c>.
        /// </summary>
        private Tensor GateBias(Tensor bias, int gate)
        {
            return tf.slice(bias, new[] { gate * Units }, new[] { Units });
        }
    }
}