From 321ddfc13ec7c91d2b8e4fea0ad9a7662dd30899 Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Sun, 15 Jan 2023 21:18:05 -0600 Subject: [PATCH] Fix Model.build. --- src/TensorFlowNET.Console/SimpleRnnTest.cs | 20 ++++++-------- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 1 + .../Keras/Layers/ILayersApi.cs | 4 ++- .../Operations/Initializers/Orthogonal.cs | 22 +++++++++++++++- .../Operations/NnOps/RNNCell.cs | 5 ++++ src/TensorFlowNET.Core/tensorflow.cs | 5 ++++ src/TensorFlowNET.Keras/Engine/Functional.cs | 7 ++++- src/TensorFlowNET.Keras/Engine/Model.Build.cs | 9 ++++--- src/TensorFlowNET.Keras/Engine/Sequential.cs | 8 +----- src/TensorFlowNET.Keras/Layers/LayersApi.cs | 10 ++++--- src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 10 ++++++- .../Layers/Rnn/SimpleRNN.cs | 14 +--------- .../Layers/Rnn/SimpleRNNCell.cs | 26 +++++++++++++++++-- .../Layers/LayersTest.cs | 9 ++++--- .../Tensorflow.Binding.UnitTest.csproj | 2 +- 15 files changed, 104 insertions(+), 48 deletions(-) diff --git a/src/TensorFlowNET.Console/SimpleRnnTest.cs b/src/TensorFlowNET.Console/SimpleRnnTest.cs index b61cee9c..da124517 100644 --- a/src/TensorFlowNET.Console/SimpleRnnTest.cs +++ b/src/TensorFlowNET.Console/SimpleRnnTest.cs @@ -12,20 +12,16 @@ namespace Tensorflow { public void Run() { - tf.keras = new KerasInterface(); - var inputs = np.random.random((32, 10, 8)).astype(np.float32); - var simple_rnn = tf.keras.layers.SimpleRNN(4); - var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`. - if (output.shape == (32, 4)) - { + tf.UseKeras<KerasInterface>(); + var inputs = np.random.random((6, 10, 8)).astype(np.float32); + //var simple_rnn = tf.keras.layers.SimpleRNN(4); + //var output = simple_rnn.Apply(inputs); // The output has shape `[6, 4]`. - } - /*simple_rnn = tf.keras.layers.SimpleRNN( - 4, return_sequences = True, return_state = True) + var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true); - # whole_sequence_output has shape `[32, 10, 4]`. 
- # final_state has shape `[32, 4]`. - whole_sequence_output, final_state = simple_rnn(inputs)*/ + // whole_sequence_output has shape `[6, 10, 4]`. + // final_state has shape `[6, 4]`. + var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index f77b4a86..1ec4a2c6 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -9,6 +9,7 @@ namespace Tensorflow.Keras string Name { get; } bool Trainable { get; } bool Built { get; } + void build(Shape input_shape); List Layers { get; } List InboundNodes { get; } List OutboundNodes { get; } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs index 3f4d1ed8..525bfd35 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs @@ -163,7 +163,9 @@ namespace Tensorflow.Keras.Layers string activation = "tanh", string kernel_initializer = "glorot_uniform", string recurrent_initializer = "orthogonal", - string bias_initializer = "zeros"); + string bias_initializer = "zeros", + bool return_sequences = false, + bool return_state = false); public ILayer Subtract(); } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs index 254a7ee7..90f3f93c 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs @@ -1,12 +1,32 @@ using System; +using System.Linq; +using static Tensorflow.TensorShapeProto.Types; namespace Tensorflow.Operations.Initializers { public class Orthogonal : IInitializer { + float _gain = 0f; + + public Orthogonal(float gain = 1.0f, int? 
seed = null) + { + _gain = gain; + } + public Tensor Apply(InitializerArgs args) { - throw new NotImplementedException(); + return _generate_init_val(args.Shape, args.DType); + } + + private Tensor _generate_init_val(Shape shape, TF_DataType dtype) + { + var num_rows = 1L; + foreach (var dim in shape.dims.Take(shape.ndim - 1)) + num_rows *= dim; + var num_cols = shape.dims.Last(); + var flat_shape = (Math.Max(num_cols, num_rows), Math.Min(num_cols, num_rows)); + + throw new NotImplementedException(""); } } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 04fdc7e5..d63d0311 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -147,5 +147,10 @@ namespace Tensorflow { throw new NotImplementedException(); } + + public void build(Shape input_shape) + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index e02723b7..35762be1 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -65,6 +65,11 @@ namespace Tensorflow InitGradientEnvironment(); } + public void UseKeras<T>() where T : IKerasApi, new() + { + keras = new T(); + } + public string VERSION => c_api.StringPiece(c_api.TF_Version()); private void InitGradientEnvironment() diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 09a31b94..d10ed214 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -65,7 +65,12 @@ namespace Tensorflow.Keras.Engine } // Keep track of the network's nodes and layers. 
- (NetworkNodes, NodesByDepth, _self_tracked_trackables, _) = MapGraphNetwork(inputs, outputs); + (NetworkNodes, NodesByDepth, var layers, _) = MapGraphNetwork(inputs, outputs); + + if (!_self_tracked_trackables.Any()) + { + _self_tracked_trackables = layers; + } // Build self.input_names and self.output_names. _set_output_names(); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Build.cs b/src/TensorFlowNET.Keras/Engine/Model.Build.cs index 1e0a880a..a51b9434 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Build.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Build.cs @@ -1,9 +1,6 @@ using System; using System.Linq; using Tensorflow.Graphs; -using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.Losses; -using Tensorflow.Keras.Optimizers; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -13,6 +10,12 @@ namespace Tensorflow.Keras.Engine { public override void build(Shape input_shape) { + if (this is Functional || this is Sequential) + { + base.build(input_shape); + return; + } + var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph(); graph.as_default(); diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index 681ab2f0..b4d1ecfe 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -122,15 +122,9 @@ namespace Tensorflow.Keras.Engine else { _self_tracked_trackables.add(layer); - _handle_deferred_layer_dependencies(layer); } } - void _handle_deferred_layer_dependencies(params ILayer[] layers) - { - _self_tracked_trackables.AddRange(layers); - } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) { if (!_has_explicit_input_shape) @@ -156,7 +150,7 @@ namespace Tensorflow.Keras.Engine ops.init_scope(); var inputs = keras.Input(batch_input_shape: input_shape, dtype: input_dtype, - name: $"{_self_tracked_trackables[0].Name}_input"); + name: _self_tracked_trackables[0].Name.EndsWith("_input") ? _self_tracked_trackables[0].Name : $"{_self_tracked_trackables[0].Name}_input"); Tensors layer_input = inputs; Tensors layer_output = null; Tensors outputs = null; diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 50c66be7..5c1c8995 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -658,14 +658,18 @@ namespace Tensorflow.Keras.Layers string activation = "tanh", string kernel_initializer = "glorot_uniform", string recurrent_initializer = "orthogonal", - string bias_initializer = "zeros") + string bias_initializer = "zeros", + bool return_sequences = false, + bool return_state = false) => new SimpleRNN(new SimpleRNNArgs { Units = units, Activation = GetActivationByName(activation), KernelInitializer = GetInitializerByName(kernel_initializer), - RecurrentInitializer= GetInitializerByName(recurrent_initializer), - BiasInitializer= GetInitializerByName(bias_initializer) + RecurrentInitializer = GetInitializerByName(recurrent_initializer), + BiasInitializer = GetInitializerByName(bias_initializer), + ReturnSequences = return_sequences, + ReturnState = return_state }); /// diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index c2b86ae4..f894f41f 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers.Rnn private int _num_constants = 0; protected IVariableV1 kernel; protected IVariableV1 bias; - + protected ILayer cell; public RNN(RNNArgs args) : base(PreConstruct(args)) { this.args = args; @@ -37,6 
+37,14 @@ namespace Tensorflow.Keras.Layers.Rnn //} } + public override void build(Shape input_shape) + { + if (!cell.Built) + { + cell.build(input_shape); + } + } + private static RNNArgs PreConstruct(RNNArgs args) { if (args.Kwargs == null) diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs index c8366ff4..a3cd002d 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs @@ -9,22 +9,10 @@ namespace Tensorflow.Keras.Layers.Rnn public class SimpleRNN : RNN { SimpleRNNArgs args; - SimpleRNNCell cell; public SimpleRNN(SimpleRNNArgs args) : base(args) { this.args = args; - } - - public override void build(Shape input_shape) - { - var input_dim = input_shape[-1]; - - kernel = add_weight("kernel", (input_shape[-1], args.Units), - initializer: args.KernelInitializer - //regularizer = self.kernel_regularizer, - //constraint = self.kernel_constraint, - //caching_device = default_caching_device, - ); + cell = new SimpleRNNCell(args); } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index 10b28e76..8d696d16 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -8,14 +8,36 @@ namespace Tensorflow.Keras.Layers.Rnn { public class SimpleRNNCell : Layer { + SimpleRNNArgs args; + IVariableV1 kernel; + IVariableV1 recurrent_kernel; + IVariableV1 bias; + public SimpleRNNCell(SimpleRNNArgs args) : base(args) { - + this.args = args; } public override void build(Shape input_shape) { - + var input_dim = input_shape[-1]; + + kernel = add_weight("kernel", (input_shape[-1], args.Units), + initializer: args.KernelInitializer + ); + + recurrent_kernel = add_weight("recurrent_kernel", (args.Units, args.Units), + initializer: args.RecurrentInitializer + ); + + if (args.UseBias) + { + bias = 
add_weight("bias", (args.Units), + initializer: args.BiasInitializer + ); + } + + built = true; } } } diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index f4fdf94a..d4ac4b90 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -150,10 +150,13 @@ namespace TensorFlowNET.Keras.UnitTest [TestMethod] public void SimpleRNN() { - var inputs = np.random.random((32, 10, 8)).astype(np.float32); - var simple_rnn = keras.layers.SimpleRNN(4); + tf.UseKeras<KerasInterface>(); + var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32); + /*var simple_rnn = keras.layers.SimpleRNN(4); var output = simple_rnn.Apply(inputs); - Assert.AreEqual((32, 4), output.shape); + Assert.AreEqual((6, 4), output.shape);*/ + var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true); + var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs); } [TestMethod] diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj index 36ff4a3d..56c212d0 100644 --- a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj @@ -47,7 +47,7 @@ - +