
Fix Model.build.

tags/v0.100.4-load-saved-model
Haiping Chen, 2 years ago
parent commit 321ddfc13e
15 changed files with 104 additions and 48 deletions
  1. src/TensorFlowNET.Console/SimpleRnnTest.cs (+8, -12)
  2. src/TensorFlowNET.Core/Keras/Layers/ILayer.cs (+1, -0)
  3. src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs (+3, -1)
  4. src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs (+21, -1)
  5. src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs (+5, -0)
  6. src/TensorFlowNET.Core/tensorflow.cs (+5, -0)
  7. src/TensorFlowNET.Keras/Engine/Functional.cs (+6, -1)
  8. src/TensorFlowNET.Keras/Engine/Model.Build.cs (+6, -3)
  9. src/TensorFlowNET.Keras/Engine/Sequential.cs (+1, -7)
  10. src/TensorFlowNET.Keras/Layers/LayersApi.cs (+7, -3)
  11. src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs (+9, -1)
  12. src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs (+1, -13)
  13. src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs (+24, -2)
  14. test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs (+6, -3)
  15. test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj (+1, -1)

src/TensorFlowNET.Console/SimpleRnnTest.cs (+8, -12)

@@ -12,20 +12,16 @@ namespace Tensorflow
     {
         public void Run()
         {
-            tf.keras = new KerasInterface();
-            var inputs = np.random.random((32, 10, 8)).astype(np.float32);
-            var simple_rnn = tf.keras.layers.SimpleRNN(4);
-            var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`.
-            if (output.shape == (32, 4))
-            {
+            tf.UseKeras<KerasInterface>();
+            var inputs = np.random.random((6, 10, 8)).astype(np.float32);
+            //var simple_rnn = tf.keras.layers.SimpleRNN(4);
+            //var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`.
 
-            }
-            /*simple_rnn = tf.keras.layers.SimpleRNN(
-                4, return_sequences = True, return_state = True)
+            var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
 
-            # whole_sequence_output has shape `[32, 10, 4]`.
-            # final_state has shape `[32, 4]`.
-            whole_sequence_output, final_state = simple_rnn(inputs)*/
+            // whole_sequence_output has shape `[32, 10, 4]`.
+            // final_state has shape `[32, 4]`.
+            var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
         }
     }
 }
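
For readers following the shapes: the `[32, ...]` comments above are carried over from the earlier 32-sample example, while the new input batch is (6, 10, 8). A minimal sketch of the same call path with the resulting shapes spelled out (not part of the commit):

    using Tensorflow.Keras;
    using Tensorflow.NumPy;
    using static Tensorflow.Binding;

    tf.UseKeras<KerasInterface>();
    var inputs = np.random.random((6, 10, 8)).astype(np.float32);
    var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
    var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
    // whole_sequence_output: (6, 10, 4) -- one 4-unit output per time step
    // final_state:           (6, 4)     -- the state after the last time step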

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs (+1, -0)

@@ -9,6 +9,7 @@ namespace Tensorflow.Keras
         string Name { get; }
         bool Trainable { get; }
         bool Built { get; }
+        void build(Shape input_shape);
         List<ILayer> Layers { get; }
         List<INode> InboundNodes { get; }
         List<INode> OutboundNodes { get; }


src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs (+3, -1)

@@ -163,7 +163,9 @@ namespace Tensorflow.Keras.Layers
             string activation = "tanh",
             string kernel_initializer = "glorot_uniform",
             string recurrent_initializer = "orthogonal",
-            string bias_initializer = "zeros");
+            string bias_initializer = "zeros",
+            bool return_sequences = false,
+            bool return_state = false);
 
         public ILayer Subtract();
     }


src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs (+21, -1)

@@ -1,12 +1,32 @@
 using System;
+using System.Linq;
+using static Tensorflow.TensorShapeProto.Types;
 
 namespace Tensorflow.Operations.Initializers
 {
     public class Orthogonal : IInitializer
     {
+        float _gain = 0f;
+
+        public Orthogonal(float gain = 1.0f, int? seed = null)
+        {
+
+        }
+
         public Tensor Apply(InitializerArgs args)
         {
-            throw new NotImplementedException();
+            return _generate_init_val(args.Shape, args.DType);
         }
+
+        private Tensor _generate_init_val(Shape shape, TF_DataType dtype)
+        {
+            var num_rows = 1L;
+            foreach (var dim in shape.dims.Take(shape.ndim - 1))
+                num_rows *= dim;
+            var num_cols = shape.dims.Last();
+            var flat_shape = (Math.Max(num_cols, num_rows), Math.Min(num_cols, num_rows));
+
+            throw new NotImplementedException("");
+        }
     }
 }
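
The new _generate_init_val stops short of producing a tensor (it still throws NotImplementedException). For orientation, the standard Keras Orthogonal initializer finishes the job by QR-decomposing a random normal matrix of flat_shape. The sketch below shows that algorithm in C#; the tf.linalg.qr and tf.linalg.tensor_diag_part calls are assumptions about bindings that may not exist in TF.NET, so read it as a description of the technique, not of this repository's API.

    using System;
    using System.Linq;
    using Tensorflow;
    using static Tensorflow.Binding;

    // Sketch only (not part of the commit): the rest of the standard orthogonal-initializer
    // algorithm. Calls marked ASSUMED are not verified TF.NET bindings.
    private Tensor _generate_init_val_sketch(Shape shape, TF_DataType dtype)
    {
        var num_rows = 1L;
        foreach (var dim in shape.dims.Take(shape.ndim - 1))
            num_rows *= dim;
        var num_cols = shape.dims.Last();
        var flat_shape = new Shape(Math.Max(num_cols, num_rows), Math.Min(num_cols, num_rows));

        var a = tf.random.normal(flat_shape, dtype: dtype);   // random Gaussian matrix
        var (q, r) = tf.linalg.qr(a);                          // ASSUMED: QR decomposition binding
        q = q * tf.sign(tf.linalg.tensor_diag_part(r));        // ASSUMED: fix signs so Q is unique
        if (num_rows < num_cols)
            q = tf.transpose(q);
        // Keras additionally scales the result by the configured gain before returning.
        return tf.reshape(q, shape);
    }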

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs (+5, -0)

@@ -147,5 +147,10 @@ namespace Tensorflow
         {
             throw new NotImplementedException();
         }
+
+        public void build(Shape input_shape)
+        {
+            throw new NotImplementedException();
+        }
     }
 }

src/TensorFlowNET.Core/tensorflow.cs (+5, -0)

@@ -65,6 +65,11 @@ namespace Tensorflow
             InitGradientEnvironment();
         }
 
+        public void UseKeras<T>() where T : IKerasApi, new()
+        {
+            keras = new T();
+        }
+
         public string VERSION => c_api.StringPiece(c_api.TF_Version());
 
         private void InitGradientEnvironment()
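
The new generic hook is what SimpleRnnTest.cs and the unit test switch to; it simply instantiates the given IKerasApi implementation and assigns it to tf.keras. A minimal usage sketch:

    using Tensorflow.Keras;
    using static Tensorflow.Binding;

    // Equivalent to the old direct assignment `tf.keras = new KerasInterface();`.
    tf.UseKeras<KerasInterface>();
    var dense = tf.keras.layers.Dense(4);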


src/TensorFlowNET.Keras/Engine/Functional.cs (+6, -1)

@@ -65,7 +65,12 @@ namespace Tensorflow.Keras.Engine
             }
 
             // Keep track of the network's nodes and layers.
-            (NetworkNodes, NodesByDepth, _self_tracked_trackables, _) = MapGraphNetwork(inputs, outputs);
+            (NetworkNodes, NodesByDepth, var layers, _) = MapGraphNetwork(inputs, outputs);
+
+            if (!_self_tracked_trackables.Any())
+            {
+                _self_tracked_trackables = layers;
+            }
 
             // Build self.input_names and self.output_names.
             _set_output_names();


src/TensorFlowNET.Keras/Engine/Model.Build.cs (+6, -3)

@@ -1,9 +1,6 @@
 using System;
 using System.Linq;
 using Tensorflow.Graphs;
-using Tensorflow.Keras.ArgsDefinition;
-using Tensorflow.Keras.Losses;
-using Tensorflow.Keras.Optimizers;
 using static Tensorflow.Binding;
 using static Tensorflow.KerasApi;
 
@@ -13,6 +10,12 @@ namespace Tensorflow.Keras.Engine
     {
         public override void build(Shape input_shape)
         {
+            if (this is Functional || this is Sequential)
+            {
+                base.build(input_shape);
+                return;
+            }
+
             var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph();
 
             graph.as_default();
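
This is the change the commit message refers to: when the model is a Functional or Sequential instance, build now delegates to the base Layer.build and returns early instead of entering the FuncGraph-based build path below. A minimal usage sketch (the shape value is illustrative):

    using Tensorflow;
    using static Tensorflow.KerasApi;

    // Hedged sketch: building a Sequential model now takes the early-return branch above.
    var model = keras.Sequential();
    model.add(keras.layers.Dense(4));
    model.build(new Shape(1, 8));   // illustrative input shape: batch of 1, 8 features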


src/TensorFlowNET.Keras/Engine/Sequential.cs (+1, -7)

@@ -122,15 +122,9 @@ namespace Tensorflow.Keras.Engine
             else
             {
                 _self_tracked_trackables.add(layer);
-                _handle_deferred_layer_dependencies(layer);
             }
         }
 
-        void _handle_deferred_layer_dependencies(params ILayer[] layers)
-        {
-            _self_tracked_trackables.AddRange(layers);
-        }
-
         protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
         {
             if (!_has_explicit_input_shape)
@@ -156,7 +150,7 @@ namespace Tensorflow.Keras.Engine
             ops.init_scope();
             var inputs = keras.Input(batch_input_shape: input_shape,
                 dtype: input_dtype,
-                name: $"{_self_tracked_trackables[0].Name}_input");
+                name: _self_tracked_trackables[0].Name.EndsWith("_input") ? _self_tracked_trackables[0].Name : $"{_self_tracked_trackables[0].Name}_input");
             Tensors layer_input = inputs;
             Tensors layer_output = null;
             Tensors outputs = null;
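
The new name expression avoids doubling the "_input" suffix when the first layer's name already ends with it. A small illustrative sketch (NameForInput is a hypothetical helper that mirrors the ternary above):

    using System;

    string NameForInput(string layerName)
        => layerName.EndsWith("_input") ? layerName : $"{layerName}_input";

    Console.WriteLine(NameForInput("dense"));             // "dense_input"  (suffix appended, as before)
    Console.WriteLine(NameForInput("simple_rnn_input"));  // "simple_rnn_input"  (no doubled suffix)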


src/TensorFlowNET.Keras/Layers/LayersApi.cs (+7, -3)

@@ -658,14 +658,18 @@ namespace Tensorflow.Keras.Layers
             string activation = "tanh",
             string kernel_initializer = "glorot_uniform",
             string recurrent_initializer = "orthogonal",
-            string bias_initializer = "zeros")
+            string bias_initializer = "zeros",
+            bool return_sequences = false,
+            bool return_state = false)
             => new SimpleRNN(new SimpleRNNArgs
             {
                 Units = units,
                 Activation = GetActivationByName(activation),
                 KernelInitializer = GetInitializerByName(kernel_initializer),
-                RecurrentInitializer= GetInitializerByName(recurrent_initializer),
-                BiasInitializer= GetInitializerByName(bias_initializer)
+                RecurrentInitializer = GetInitializerByName(recurrent_initializer),
+                BiasInitializer = GetInitializerByName(bias_initializer),
+                ReturnSequences = return_sequences,
+                ReturnState = return_state
             });
 
         /// <summary>


src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs (+9, -1)

@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers.Rnn
         private int _num_constants = 0;
         protected IVariableV1 kernel;
         protected IVariableV1 bias;
-
+        protected ILayer cell;
         public RNN(RNNArgs args) : base(PreConstruct(args))
         {
             this.args = args;
@@ -37,6 +37,14 @@ namespace Tensorflow.Keras.Layers.Rnn
             //}
         }
 
+        public override void build(Shape input_shape)
+        {
+            if (!cell.Built)
+            {
+                cell.build(input_shape);
+            }
+        }
+
         private static RNNArgs PreConstruct(RNNArgs args)
         {
             if (args.Kwargs == null)


src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs (+1, -13)

@@ -9,22 +9,10 @@ namespace Tensorflow.Keras.Layers.Rnn
     public class SimpleRNN : RNN
     {
         SimpleRNNArgs args;
-        SimpleRNNCell cell;
         public SimpleRNN(SimpleRNNArgs args) : base(args)
        {
             this.args = args;
-        }
-
-        public override void build(Shape input_shape)
-        {
-            var input_dim = input_shape[-1];
-
-            kernel = add_weight("kernel", (input_shape[-1], args.Units),
-                initializer: args.KernelInitializer
-                //regularizer = self.kernel_regularizer,
-                //constraint = self.kernel_constraint,
-                //caching_device = default_caching_device,
-            );
+            cell = new SimpleRNNCell(args);
         }
     }
 }

src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs (+24, -2)

@@ -8,14 +8,36 @@ namespace Tensorflow.Keras.Layers.Rnn
 {
     public class SimpleRNNCell : Layer
     {
+        SimpleRNNArgs args;
+        IVariableV1 kernel;
+        IVariableV1 recurrent_kernel;
+        IVariableV1 bias;
+
         public SimpleRNNCell(SimpleRNNArgs args) : base(args)
         {
-
+            this.args = args;
         }
 
         public override void build(Shape input_shape)
         {
-
+            var input_dim = input_shape[-1];
+
+            kernel = add_weight("kernel", (input_shape[-1], args.Units),
+                initializer: args.KernelInitializer
+            );
+
+            recurrent_kernel = add_weight("recurrent_kernel", (args.Units, args.Units),
+                initializer: args.RecurrentInitializer
+            );
+
+            if (args.UseBias)
+            {
+                bias = add_weight("bias", (args.Units),
+                    initializer: args.RecurrentInitializer
+                );
+            }
+
+            built = true;
         }
     }
 }
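
For the (6, 10, 8) inputs used elsewhere in this commit and Units = 4, input_shape[-1] is 8, so build creates kernel with shape (8, 4), recurrent_kernel with shape (4, 4), and, when UseBias is set, bias with shape (4). This mirrors the weight creation previously done in SimpleRNN.build, which the commit removes.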

test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs (+6, -3)

@@ -150,10 +150,13 @@ namespace TensorFlowNET.Keras.UnitTest
         [TestMethod]
         public void SimpleRNN()
         {
-            var inputs = np.random.random((32, 10, 8)).astype(np.float32);
-            var simple_rnn = keras.layers.SimpleRNN(4);
+            tf.UseKeras<KerasInterface>();
+            var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
+            /*var simple_rnn = keras.layers.SimpleRNN(4);
             var output = simple_rnn.Apply(inputs);
-            Assert.AreEqual((32, 4), output.shape);
+            Assert.AreEqual((32, 4), output.shape);*/
+            var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
+            var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
         }
 
         [TestMethod]


test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj (+1, -1)

@@ -47,7 +47,7 @@
 
   <ItemGroup>
     <PackageReference Include="FluentAssertions" Version="5.10.3" />
-    <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.144" />
+    <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
     <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.0.0" />
     <PackageReference Include="MSTest.TestAdapter" Version="2.2.8" />
    <PackageReference Include="MSTest.TestFramework" Version="2.2.8" />

