Browse Source

Finish SimpleRNNCell and add test

pull/1090/head
Wanglongzhi2001 2 years ago
parent
commit
08b4b89f77
6 changed files with 29 additions and 23 deletions
  1. +3
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  2. +2
    -1
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  3. +5
    -1
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  4. +6
    -3
      src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs
  5. +6
    -12
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
  6. +7
    -5
      test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

+ 3
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -206,7 +206,9 @@ namespace Tensorflow.Keras.Layers
bool use_bias = true, bool use_bias = true,
string kernel_initializer = "glorot_uniform", string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal", string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros");
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f);


public ILayer Subtract(); public ILayer Subtract();
} }


+ 2
- 1
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -4633,8 +4633,9 @@ public static class gen_math_ops
var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatMul", name) { args = new object[] { a, b }, attrs = new Dictionary<string, object>() { ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b } }); var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatMul", name) { args = new object[] { a, b }, attrs = new Dictionary<string, object>() { ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b } });
return _fast_path_result[0]; return _fast_path_result[0];
} }
catch (Exception)
catch (ArgumentException)
{ {
throw new ArgumentException("In[0] and In[1] have different ndims!");
} }
try try
{ {


+ 5
- 1
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -715,7 +715,9 @@ namespace Tensorflow.Keras.Layers
bool use_bias = true, bool use_bias = true,
string kernel_initializer = "glorot_uniform", string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal", string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros")
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f)
=> new SimpleRNNCell(new SimpleRNNArgs => new SimpleRNNCell(new SimpleRNNArgs
{ {
Units = units, Units = units,
@@ -723,6 +725,8 @@ namespace Tensorflow.Keras.Layers
UseBias = use_bias, UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer), KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer = GetInitializerByName(recurrent_initializer), RecurrentInitializer = GetInitializerByName(recurrent_initializer),
Dropout = dropout,
RecurrentDropout = recurrent_dropout
} }
); );




+ 6
- 3
src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs View File

@@ -13,9 +13,10 @@ namespace Tensorflow.Keras.Layers.Rnn
public float dropout; public float dropout;
public float recurrent_dropout; public float recurrent_dropout;
// Get the dropout mask for RNN cell's input. // Get the dropout mask for RNN cell's input.
public Tensors get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
{ {

if (dropout == 0f)
return null;
return _generate_dropout_mask( return _generate_dropout_mask(
tf.ones_like(input), tf.ones_like(input),
dropout, dropout,
@@ -24,8 +25,10 @@ namespace Tensorflow.Keras.Layers.Rnn
} }


// Get the recurrent dropout mask for RNN cell. // Get the recurrent dropout mask for RNN cell.
public Tensors get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
{ {
if (dropout == 0f)
return null;
return _generate_dropout_mask( return _generate_dropout_mask(
tf.ones_like(input), tf.ones_like(input),
recurrent_dropout, recurrent_dropout,


+ 6
- 12
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs View File

@@ -58,10 +58,7 @@ namespace Tensorflow.Keras.Layers.Rnn


protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{ {
Console.WriteLine($"shape of input: {inputs.shape}");
Tensor states = initial_state[0]; Tensor states = initial_state[0];
Console.WriteLine($"shape of initial_state: {states.shape}");

var prev_output = nest.is_nested(states) ? states[0] : states; var prev_output = nest.is_nested(states) ? states[0] : states;
var dp_mask = DRCMixin.get_dropout_maskcell_for_cell(inputs, training.Value); var dp_mask = DRCMixin.get_dropout_maskcell_for_cell(inputs, training.Value);
var rec_dp_mask = DRCMixin.get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value); var rec_dp_mask = DRCMixin.get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value);
@@ -72,11 +69,12 @@ namespace Tensorflow.Keras.Layers.Rnn
{ {
if (ranks > 2) if (ranks > 2)
{ {
h = tf.linalg.tensordot(tf.multiply(inputs, dp_mask), kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
// Because the multiply function automatically adds a leading dimension, index [0] is used to drop it
h = tf.linalg.tensordot(math_ops.multiply(inputs, dp_mask)[0], kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
} }
else else
{ {
h = math_ops.matmul(tf.multiply(inputs, dp_mask), kernel.AsTensor());
h = math_ops.matmul(math_ops.multiply(inputs, dp_mask)[0], kernel.AsTensor());
} }
} }
else else
@@ -98,22 +96,18 @@ namespace Tensorflow.Keras.Layers.Rnn


if (rec_dp_mask != null) if (rec_dp_mask != null)
{ {
prev_output = tf.multiply(prev_output, rec_dp_mask);
prev_output = math_ops.multiply(prev_output, rec_dp_mask)[0];
} }


ranks = prev_output.rank; ranks = prev_output.rank;
Console.WriteLine($"shape of h: {h.shape}");

Tensor output; Tensor output;
if (ranks > 2) if (ranks > 2)
{ {
var tmp = tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
output = h + tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } })[0];
output = h + tf.linalg.tensordot(prev_output[0], recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
} }
else else
{ {
output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor())[0];

output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor());
} }
Console.WriteLine($"shape of output: {output.shape}"); Console.WriteLine($"shape of output: {output.shape}");




+ 7
- 5
test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs View File

@@ -147,13 +147,15 @@ namespace Tensorflow.Keras.UnitTest.Layers
[TestMethod] [TestMethod]
public void SimpleRNNCell() public void SimpleRNNCell()
{ {
var cell = keras.layers.SimpleRNNCell(64, dropout:0.5f, recurrent_dropout:0.5f);
var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
var x = tf.random.normal(new Shape(4, 100));
var cell = keras.layers.SimpleRNNCell(64);
var (y, h1) = cell.Apply(inputs:x, state:h0);
var x = tf.random.normal((4, 100));
var (y, h1) = cell.Apply(inputs: x, state: h0);
// TODO(Wanglongzhi2001): SimpleRNNCell needs to return both a Tensor and a Tensors; a single Tensors
// cannot hold both, so h is explicitly cast to Tensors by the caller here.
var h2 = (Tensors)h1;
Assert.AreEqual((4, 64), y.shape); Assert.AreEqual((4, 64), y.shape);
// this test now cannot pass, need to deal with SimpleRNNCell's Call method
//Assert.AreEqual((4, 64), h1[0].shape);
Assert.AreEqual((4, 64), h2[0].shape);
} }


[TestMethod, Ignore("WIP")] [TestMethod, Ignore("WIP")]


Loading…
Cancel
Save