diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index e94c8bf1..2f92c4e5 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -15,7 +15,7 @@ namespace Tensorflow.Keras List Layers { get; } List InboundNodes { get; } List OutboundNodes { get; } - Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null); + Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null); List TrainableVariables { get; } List TrainableWeights { get; } List NonTrainableWeights { get; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index e488c47e..4e99731f 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -145,7 +145,7 @@ namespace Tensorflow throw new NotImplementedException("_zero_state_tensors"); } - public Tensors Apply(Tensors inputs, Tensors state = null, bool is_training = false, IOptionalArgs? optional_args = null) + public Tensors Apply(Tensors inputs, Tensors state = null, bool? is_training = false, IOptionalArgs? optional_args = null) { throw new NotImplementedException(); } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs index d52190fd..8a66948b 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs @@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Engine /// /// /// - public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null) + public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null) { if (callContext.Value == null) callContext.Value = new CallContext(); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index 76c592ad..de57f19a 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -142,6 +142,7 @@ namespace Tensorflow.Keras.Engine int verbose = 1, List callbacks = null, IDatasetV2 validation_data = null, + int validation_step = 10, // 间隔多少次会进行一次验证 bool shuffle = true, int initial_epoch = 0, int max_queue_size = 10, @@ -164,11 +165,11 @@ namespace Tensorflow.Keras.Engine }); - return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data, + return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data, train_step_func: train_step_function); } - History FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, IDatasetV2 validation_data, + History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List callbackList, IDatasetV2 validation_data, Func> train_step_func) { stop_training = false; @@ -207,6 +208,9 @@ namespace Tensorflow.Keras.Engine if (validation_data != null) { + if (validation_step > 0 && epoch ==0 || (epoch) % validation_step != 0) + continue; + var val_logs = evaluate(validation_data); foreach(var log in val_logs) { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index f86de8a8..0ca62c39 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -393,7 +393,7 @@ namespace Tensorflow.Keras.Layers.Rnn } } - public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null) + public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool? training = false, IOptionalArgs? optional_args = null) { RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; if (optional_args is not null && rnn_optional_args is null)