diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs index 75d5d021..624756af 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs @@ -3,7 +3,7 @@ using System; using System.Collections.Generic; using System.Text; -namespace Tensorflow.Keras.ArgsDefinition.Rnn +namespace Tensorflow.Keras.ArgsDefinition { public class GRUCellArgs : AutoSerializeLayerArgs { diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs index db76fda0..d816b0ff 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs @@ -1,4 +1,4 @@ -namespace Tensorflow.Keras.ArgsDefinition.Rnn +namespace Tensorflow.Keras.ArgsDefinition { public class LSTMArgs : RNNArgs { diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs index 786236e4..f4503231 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs @@ -1,7 +1,7 @@ using Newtonsoft.Json; using static Tensorflow.Binding; -namespace Tensorflow.Keras.ArgsDefinition.Rnn +namespace Tensorflow.Keras.ArgsDefinition { // TODO: complete the implementation public class LSTMCellArgs : AutoSerializeLayerArgs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs index 2d7fb001..b84d30d3 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs @@ -1,8 +1,8 @@ using Newtonsoft.Json; using System.Collections.Generic; -using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Layers; -namespace Tensorflow.Keras.ArgsDefinition.Rnn +namespace Tensorflow.Keras.ArgsDefinition { // TODO(Rinne): add regularizers. public class RNNArgs : AutoSerializeLayerArgs @@ -23,16 +23,22 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn public int? InputDim { get; set; } public int? InputLength { get; set; } // TODO: Add `num_constants` and `zero_output_for_mask`. - + [JsonProperty("units")] public int Units { get; set; } + [JsonProperty("activation")] public Activation Activation { get; set; } + [JsonProperty("recurrent_activation")] public Activation RecurrentActivation { get; set; } + [JsonProperty("use_bias")] public bool UseBias { get; set; } = true; public IInitializer KernelInitializer { get; set; } public IInitializer RecurrentInitializer { get; set; } public IInitializer BiasInitializer { get; set; } + [JsonProperty("dropout")] public float Dropout { get; set; } = .0f; + [JsonProperty("zero_output_for_mask")] public bool ZeroOutputForMask { get; set; } = false; + [JsonProperty("recurrent_dropout")] public float RecurrentDropout { get; set; } = .0f; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs index 64b500bb..a6520589 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs @@ -3,7 +3,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Common.Types; -namespace Tensorflow.Keras.ArgsDefinition.Rnn +namespace Tensorflow.Keras.ArgsDefinition { public class RnnOptionalArgs: IOptionalArgs { diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs index fcfd694d..e45ef79d 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs @@ -1,4 +1,4 @@ -namespace Tensorflow.Keras.ArgsDefinition.Rnn +namespace Tensorflow.Keras.ArgsDefinition { public class SimpleRNNArgs : RNNArgs { diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs index d21d6190..b84ea21b 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs @@ -1,6 +1,6 @@ using Newtonsoft.Json; -namespace Tensorflow.Keras.ArgsDefinition.Rnn +namespace Tensorflow.Keras.ArgsDefinition { public class SimpleRNNCellArgs: AutoSerializeLayerArgs { diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs index 50a6127d..2600f14e 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs @@ -1,7 +1,7 @@ using System.Collections.Generic; -using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Layers; -namespace Tensorflow.Keras.ArgsDefinition.Rnn +namespace Tensorflow.Keras.ArgsDefinition { public class StackedRNNCellsArgs : LayerArgs { diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs index b48cd553..1670f9d1 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs @@ -1,7 +1,7 @@ using System; using Tensorflow.Framework.Models; using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Layers; using Tensorflow.NumPy; using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs index 8d6fbc97..43df75b1 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs @@ -3,7 +3,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Common.Types; -namespace Tensorflow.Keras.Layers.Rnn +namespace Tensorflow.Keras.Layers { public interface IRnnCell: ILayer { diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs index e73244a5..8cf6150d 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.Text; -namespace Tensorflow.Keras.Layers.Rnn +namespace Tensorflow.Keras.Layers { public interface IStackedRnnCells : IRnnCell { diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 4e99731f..9905d39c 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -19,9 +19,8 @@ using System.Collections.Generic; using Tensorflow.Common.Types; using Tensorflow.Keras; using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Layers; using Tensorflow.Keras.Saving; using Tensorflow.NumPy; using Tensorflow.Operations; diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 2dc46329..c624c990 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -571,7 +571,9 @@ namespace Tensorflow if (tf.Context.executing_eagerly()) return true; else - throw new NotImplementedException(""); + // TODO(Wanglongzhi2001), implement the false case + return true; + //throw new NotImplementedException(""); } public static bool inside_function() diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 5968461d..cb85bbba 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -2,9 +2,8 @@ using Tensorflow.Framework.Models; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Core; -using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Layers; using Tensorflow.NumPy; using static Tensorflow.Binding; using static Tensorflow.KerasApi; diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs index 75feb8ea..27c13f34 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs @@ -6,7 +6,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; -namespace Tensorflow.Keras.Layers.Rnn +namespace Tensorflow.Keras.Layers { public abstract class DropoutRNNCellMixin: Layer, IRnnCell { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs index 02fe54f4..2b9c01e3 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs @@ -3,12 +3,11 @@ using System.Collections.Generic; using System.Diagnostics; using System.Text; using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Common.Extensions; using Tensorflow.Common.Types; using Tensorflow.Keras.Saving; -namespace Tensorflow.Keras.Layers.Rnn +namespace Tensorflow.Keras.Layers { /// /// Cell class for the GRU layer. diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs index 025465fd..b5d58324 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs @@ -1,10 +1,10 @@ using System.Linq; -using Tensorflow.Keras.ArgsDefinition.Rnn; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Common.Types; using Tensorflow.Common.Extensions; -namespace Tensorflow.Keras.Layers.Rnn +namespace Tensorflow.Keras.Layers { /// /// Long Short-Term Memory layer - Hochreiter 1997. diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs index 284a2b77..e4fc6dd2 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs @@ -3,12 +3,12 @@ using Serilog.Core; using System.Diagnostics; using Tensorflow.Common.Extensions; using Tensorflow.Common.Types; -using Tensorflow.Keras.ArgsDefinition.Rnn; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; using Tensorflow.Keras.Utils; -namespace Tensorflow.Keras.Layers.Rnn +namespace Tensorflow.Keras.Layers { /// /// Cell class for the LSTM layer. diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index 6075547b..0e81d20e 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Reflection; using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; using Tensorflow.Util; @@ -14,7 +13,7 @@ using Tensorflow.Common.Types; using System.Runtime.CompilerServices; // from tensorflow.python.distribute import distribution_strategy_context as ds_context; -namespace Tensorflow.Keras.Layers.Rnn +namespace Tensorflow.Keras.Layers { /// /// Base class for recurrent layers. @@ -185,6 +184,7 @@ namespace Tensorflow.Keras.Layers.Rnn public override void build(KerasShapesWrapper input_shape) { + _buildInputShape = input_shape; input_shape = new KerasShapesWrapper(input_shape.Shapes[0]); InputSpec get_input_spec(Shape shape) diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs index 018b1778..1419da4b 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs @@ -4,7 +4,7 @@ using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; -namespace Tensorflow.Keras.Layers.Rnn +namespace Tensorflow.Keras.Layers { public abstract class RnnBase: Layer { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs index a22f31c7..9c199eb4 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs @@ -1,11 +1,11 @@ using System.Data; -using Tensorflow.Keras.ArgsDefinition.Rnn; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Saving; using Tensorflow.Operations.Activation; using static HDF.PInvoke.H5Z; using static Tensorflow.ApiDef.Types; -namespace Tensorflow.Keras.Layers.Rnn +namespace Tensorflow.Keras.Layers { public class SimpleRNN : RNN { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index c77f7779..e74b5692 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -1,7 +1,7 @@ using System; using System.Collections.Generic; using System.Text; -using Tensorflow.Keras.ArgsDefinition.Rnn; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; using Tensorflow.Common.Types; @@ -9,7 +9,7 @@ using Tensorflow.Common.Extensions; using Tensorflow.Keras.Utils; using Tensorflow.Graphs; -namespace Tensorflow.Keras.Layers.Rnn +namespace Tensorflow.Keras.Layers { /// /// Cell class for SimpleRNN. diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs index 8799bfb2..ece2bc5b 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs @@ -3,12 +3,12 @@ using System.ComponentModel; using System.Linq; using Tensorflow.Common.Extensions; using Tensorflow.Common.Types; -using Tensorflow.Keras.ArgsDefinition.Rnn; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; using Tensorflow.Keras.Utils; -namespace Tensorflow.Keras.Layers.Rnn +namespace Tensorflow.Keras.Layers { public class StackedRNNCells : Layer, IRnnCell { diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index fd1453d3..0bd816cc 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -13,7 +13,6 @@ using Tensorflow.Framework.Models; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; -using Tensorflow.Keras.Layers.Rnn; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Saving.SavedModel; diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index 0ec5d1a8..325d3327 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -3,7 +3,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Layers; using Tensorflow.Keras.Metrics; using Tensorflow.Train; diff --git a/src/TensorFlowNET.Keras/Utils/RnnUtils.cs b/src/TensorFlowNET.Keras/Utils/RnnUtils.cs index e8700c1f..1e9f6d84 100644 --- a/src/TensorFlowNET.Keras/Utils/RnnUtils.cs +++ b/src/TensorFlowNET.Keras/Utils/RnnUtils.cs @@ -3,7 +3,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Text; using Tensorflow.Common.Types; -using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Layers; using Tensorflow.Common.Extensions; namespace Tensorflow.Keras.Utils diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs index becdbcd6..5f7bd574 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs @@ -6,7 +6,7 @@ using System.Text; using System.Threading.Tasks; using Tensorflow.Common.Types; using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Layers; using Tensorflow.Keras.Saving; using Tensorflow.NumPy; using Tensorflow.Train; diff --git a/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs b/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs index 10db2bd1..382941d9 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs @@ -1,5 +1,7 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.VisualStudio.TestPlatform.Utilities; +using Microsoft.VisualStudio.TestTools.UnitTesting; using System.Linq; +using Tensorflow.Keras.Engine; using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.UnitTest.Helpers; using Tensorflow.NumPy; @@ -79,6 +81,31 @@ public class ModelLoadTest model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); } + [TestMethod] + public void LSTMLoad() + { + var inputs = np.random.randn(10, 5, 3); + var outputs = np.random.randn(10, 1); + var model = keras.Sequential(); + model.add(keras.Input(shape: (5, 3))); + var lstm = keras.layers.LSTM(32); + + model.add(lstm); + + model.add(keras.layers.Dense(1, keras.activations.Sigmoid)); + + model.compile(optimizer: keras.optimizers.Adam(), + loss: keras.losses.MeanSquaredError(), + new[] { "accuracy" }); + + var result = model.fit(inputs.numpy(), outputs.numpy(), batch_size: 10, epochs: 3, workers: 16, use_multiprocessing: true); + + model.save("LSTM_Random"); + + var model_loaded = keras.models.load_model("LSTM_Random"); + model_loaded.summary(); + } + [Ignore] [TestMethod] public void VGG19()