the bool to tensor has a bug, if in init the training is False, the program not start.tags/v0.110.4-Transformer-Model
| @@ -15,7 +15,7 @@ namespace Tensorflow.Keras | |||||
| List<ILayer> Layers { get; } | List<ILayer> Layers { get; } | ||||
| List<INode> InboundNodes { get; } | List<INode> InboundNodes { get; } | ||||
| List<INode> OutboundNodes { get; } | List<INode> OutboundNodes { get; } | ||||
| Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null); | |||||
| Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null); | |||||
| List<IVariableV1> TrainableVariables { get; } | List<IVariableV1> TrainableVariables { get; } | ||||
| List<IVariableV1> TrainableWeights { get; } | List<IVariableV1> TrainableWeights { get; } | ||||
| List<IVariableV1> NonTrainableWeights { get; } | List<IVariableV1> NonTrainableWeights { get; } | ||||
| @@ -145,7 +145,7 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("_zero_state_tensors"); | throw new NotImplementedException("_zero_state_tensors"); | ||||
| } | } | ||||
| public Tensors Apply(Tensors inputs, Tensors state = null, bool is_training = false, IOptionalArgs? optional_args = null) | |||||
| public Tensors Apply(Tensors inputs, Tensors state = null, bool? is_training = false, IOptionalArgs? optional_args = null) | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| @@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Engine | |||||
| /// <param name="state"></param> | /// <param name="state"></param> | ||||
| /// <param name="training"></param> | /// <param name="training"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null) | |||||
| public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null) | |||||
| { | { | ||||
| if (callContext.Value == null) | if (callContext.Value == null) | ||||
| callContext.Value = new CallContext(); | callContext.Value = new CallContext(); | ||||
| @@ -142,6 +142,7 @@ namespace Tensorflow.Keras.Engine | |||||
| int verbose = 1, | int verbose = 1, | ||||
| List<ICallback> callbacks = null, | List<ICallback> callbacks = null, | ||||
| IDatasetV2 validation_data = null, | IDatasetV2 validation_data = null, | ||||
| int validation_step = 10, // 间隔多少次会进行一次验证 | |||||
| bool shuffle = true, | bool shuffle = true, | ||||
| int initial_epoch = 0, | int initial_epoch = 0, | ||||
| int max_queue_size = 10, | int max_queue_size = 10, | ||||
| @@ -164,11 +165,11 @@ namespace Tensorflow.Keras.Engine | |||||
| }); | }); | ||||
| return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data, | |||||
| return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data, | |||||
| train_step_func: train_step_function); | train_step_func: train_step_function); | ||||
| } | } | ||||
| History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data, | |||||
| History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data, | |||||
| Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | ||||
| { | { | ||||
| stop_training = false; | stop_training = false; | ||||
| @@ -207,6 +208,9 @@ namespace Tensorflow.Keras.Engine | |||||
| if (validation_data != null) | if (validation_data != null) | ||||
| { | { | ||||
| if (validation_step > 0 && epoch ==0 || (epoch) % validation_step != 0) | |||||
| continue; | |||||
| var val_logs = evaluate(validation_data); | var val_logs = evaluate(validation_data); | ||||
| foreach(var log in val_logs) | foreach(var log in val_logs) | ||||
| { | { | ||||
| @@ -393,7 +393,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| } | } | ||||
| } | } | ||||
| public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null) | |||||
| public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool? training = false, IOptionalArgs? optional_args = null) | |||||
| { | { | ||||
| RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; | RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; | ||||
| if (optional_args is not null && rnn_optional_args is null) | if (optional_args is not null && rnn_optional_args is null) | ||||