Add EarlyStopping callback (tags/v0.100.5-BERT-load)
| @@ -4,6 +4,7 @@ public interface ICallback | |||
| { | |||
| Dictionary<string, List<float>> history { get; set; } | |||
| void on_train_begin(); | |||
| void on_train_end(); | |||
| void on_epoch_begin(int epoch); | |||
| void on_train_batch_begin(long step); | |||
| void on_train_batch_end(long end_step, Dictionary<string, float> logs); | |||
| @@ -17,6 +17,7 @@ public interface IModel : ILayer | |||
| int batch_size = -1, | |||
| int epochs = 1, | |||
| int verbose = 1, | |||
| List<ICallback> callbacks = null, | |||
| float validation_split = 0f, | |||
| bool shuffle = true, | |||
| int initial_epoch = 0, | |||
| @@ -28,6 +29,7 @@ public interface IModel : ILayer | |||
| int batch_size = -1, | |||
| int epochs = 1, | |||
| int verbose = 1, | |||
| List<ICallback> callbacks = null, | |||
| float validation_split = 0f, | |||
| bool shuffle = true, | |||
| int initial_epoch = 0, | |||
| @@ -73,4 +75,6 @@ public interface IModel : ILayer | |||
| void summary(int line_length = -1, float[] positions = null); | |||
| IKerasConfig get_config(); | |||
| void set_stopTraining_true(); | |||
| } | |||
| @@ -7,7 +7,8 @@ namespace Tensorflow.Keras.Callbacks; | |||
| public class CallbackList | |||
| { | |||
| List<ICallback> callbacks = new List<ICallback>(); | |||
| // Made public so that user-defined callbacks can be added to the callbacks list. | |||
| public List<ICallback> callbacks = new List<ICallback>(); | |||
| public History History => callbacks[0] as History; | |||
| public CallbackList(CallbackParams parameters) | |||
| @@ -66,7 +67,7 @@ public class CallbackList | |||
| public void on_test_batch_begin(long step) | |||
| { | |||
| callbacks.ForEach(x => x.on_train_batch_begin(step)); | |||
| callbacks.ForEach(x => x.on_test_batch_begin(step)); | |||
| } | |||
| public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||
| { | |||
| @@ -0,0 +1,155 @@ | |||
| using Tensorflow.Keras.Engine; | |||
| namespace Tensorflow.Keras.Callbacks; | |||
| /// <summary> | |||
| /// Stop training when a monitored metric has stopped improving. | |||
| /// </summary> | |||
| /// <param name="parameters"></param> | |||
| /// <param name="monitor"></param> | |||
| public class EarlyStopping: ICallback | |||
| { | |||
| int _paitence; | |||
| int _min_delta; | |||
| int _verbose; | |||
| int _stopped_epoch; | |||
| int _wait; | |||
| int _best_epoch; | |||
| int _start_from_epoch; | |||
| float _best; | |||
| float _baseline; | |||
| string _monitor; | |||
| string _mode; | |||
| bool _restore_best_weights; | |||
| List<IVariableV1>? _best_weights; | |||
| CallbackParams _parameters; | |||
| public Dictionary<string, List<float>>? history { get; set; } | |||
| // user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model | |||
| public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", int min_delta = 0, int patience = 0, | |||
| int verbose = 1, string mode = "auto", float baseline = 0f, bool restore_best_weights = false, | |||
| int start_from_epoch = 0) | |||
| { | |||
| _parameters = parameters; | |||
| _stopped_epoch = 0; | |||
| _wait = 0; | |||
| _monitor = monitor; | |||
| _paitence = patience; | |||
| _verbose = verbose; | |||
| _baseline = baseline; | |||
| _start_from_epoch = start_from_epoch; | |||
| _min_delta = Math.Abs(min_delta); | |||
| _restore_best_weights = restore_best_weights; | |||
| _mode = mode; | |||
| if (mode != "auto" && mode != "min" && mode != "max") | |||
| { | |||
| Console.WriteLine("EarlyStopping mode %s is unknown, fallback to auto mode.", mode); | |||
| } | |||
| } | |||
| public void on_train_begin() | |||
| { | |||
| _wait = 0; | |||
| _stopped_epoch = 0; | |||
| _best_epoch = 0; | |||
| _best = (float)np.Inf; | |||
| } | |||
| public void on_epoch_begin(int epoch) | |||
| { | |||
| } | |||
| public void on_train_batch_begin(long step) | |||
| { | |||
| } | |||
| public void on_train_batch_end(long end_step, Dictionary<string, float> logs) | |||
| { | |||
| } | |||
| public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs) | |||
| { | |||
| var current = get_monitor_value(epoch_logs); | |||
| // If no monitor value exists or still in initial warm-up stage. | |||
| if (current == 0f || epoch < _start_from_epoch) | |||
| return; | |||
| // Restore the weights after first epoch if no progress is ever made. | |||
| if (_restore_best_weights && _best_weights == null) | |||
| { | |||
| _best_weights = _parameters.Model.TrainableWeights; | |||
| } | |||
| _wait += 1; | |||
| if (_is_improvement(current, _best)) | |||
| { | |||
| _best = current; | |||
| _best_epoch = epoch; | |||
| if (_restore_best_weights) | |||
| _best_weights = _parameters.Model.TrainableWeights; | |||
| // Only restart wait if we beat both the baseline and our previous best. | |||
| if (_baseline == 0f || _is_improvement(current, _baseline)) | |||
| _wait = 0; | |||
| } | |||
| // Only check after the first epoch. | |||
| if (_wait >= _paitence && epoch > 0) | |||
| { | |||
| _stopped_epoch = epoch; | |||
| _parameters.Model.set_stopTraining_true(); | |||
| if (_restore_best_weights && _best_weights != null) | |||
| { | |||
| if (_verbose > 0) | |||
| { | |||
| Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); | |||
| } | |||
| } | |||
| // Because loading the weight variable into the model has not yet been implemented, so Earlystopping can't load best_weight yet. | |||
| // TODO(Wanglongzhi2001): implement it. | |||
| // _parameters.Model.load_weights(best_weights); | |||
| } | |||
| } | |||
| public void on_train_end() | |||
| { | |||
| if (_stopped_epoch > 0 && _verbose > 0) | |||
| { | |||
| Console.WriteLine($"Epoch {_stopped_epoch + 1}: early stopping"); | |||
| } | |||
| } | |||
| public void on_predict_begin() { } | |||
| public void on_predict_batch_begin(long step) { } | |||
| public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) { } | |||
| public void on_predict_end() { } | |||
| public void on_test_begin() { } | |||
| public void on_test_batch_begin(long step) { } | |||
| public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { } | |||
| float get_monitor_value(Dictionary<string, float> logs) | |||
| { | |||
| logs = logs ?? new Dictionary<string, float>(); | |||
| float monitor_value = logs[_monitor]; | |||
| if (monitor_value == 0f) | |||
| { | |||
| Console.WriteLine($"Early stopping conditioned on metric {_monitor} " + | |||
| $"which is not available. Available metrics are: {string.Join(", ", logs.Keys)}"); | |||
| } | |||
| return monitor_value; | |||
| } | |||
| public bool _is_improvement(float monitor_value, float reference_value) | |||
| { | |||
| bool less_op = (monitor_value - _min_delta) < reference_value; | |||
| bool greater_op = (monitor_value - _min_delta) >= reference_value; | |||
| if (_mode == "min") | |||
| return less_op; | |||
| else if (_mode == "max") | |||
| return greater_op; | |||
| else | |||
| { | |||
| if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc")) | |||
| { | |||
| return greater_op; | |||
| } | |||
| else | |||
| return less_op; | |||
| } | |||
| } | |||
| } | |||
| @@ -23,6 +23,7 @@ public class History : ICallback | |||
| epochs = new List<int>(); | |||
| history = new Dictionary<string, List<float>>(); | |||
| } | |||
| public void on_train_end() { } | |||
| public void on_epoch_begin(int epoch) | |||
| { | |||
| @@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Callbacks | |||
| _called_in_fit = true; | |||
| _sw = new Stopwatch(); | |||
| } | |||
| public void on_train_end() { } | |||
| public void on_test_begin() | |||
| { | |||
| _sw = new Stopwatch(); | |||
| @@ -19,6 +19,7 @@ namespace Tensorflow.Keras.Engine | |||
| /// <param name="y"></param> | |||
| /// <param name="batch_size"></param> | |||
| /// <param name="epochs"></param> | |||
| /// <param name="callbacks"></param> | |||
| /// <param name="verbose"></param> | |||
| /// <param name="validation_split"></param> | |||
| /// <param name="shuffle"></param> | |||
| @@ -26,6 +27,7 @@ namespace Tensorflow.Keras.Engine | |||
| int batch_size = -1, | |||
| int epochs = 1, | |||
| int verbose = 1, | |||
| List<ICallback> callbacks = null, | |||
| float validation_split = 0f, | |||
| bool shuffle = true, | |||
| int initial_epoch = 0, | |||
| @@ -59,7 +61,7 @@ namespace Tensorflow.Keras.Engine | |||
| StepsPerExecution = _steps_per_execution | |||
| }); | |||
| return FitInternal(data_handler, epochs, verbose, validation_data: null, | |||
| return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||
| train_step_func: train_step_function); | |||
| } | |||
| @@ -67,6 +69,7 @@ namespace Tensorflow.Keras.Engine | |||
| int batch_size = -1, | |||
| int epochs = 1, | |||
| int verbose = 1, | |||
| List<ICallback> callbacks = null, | |||
| float validation_split = 0f, | |||
| bool shuffle = true, | |||
| int initial_epoch = 0, | |||
| @@ -107,12 +110,12 @@ namespace Tensorflow.Keras.Engine | |||
| if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || | |||
| data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) | |||
| { | |||
| return FitInternal(data_handler, epochs, verbose, validation_data: null, | |||
| return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||
| train_step_func: train_step_multi_inputs_function); | |||
| } | |||
| else | |||
| { | |||
| return FitInternal(data_handler, epochs, verbose, validation_data: null, | |||
| return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||
| train_step_func: train_step_function); | |||
| } | |||
| } | |||
| @@ -122,6 +125,7 @@ namespace Tensorflow.Keras.Engine | |||
| int batch_size = -1, | |||
| int epochs = 1, | |||
| int verbose = 1, | |||
| List<ICallback> callbacks = null, | |||
| float validation_split = 0f, | |||
| bool shuffle = true, | |||
| int initial_epoch = 0, | |||
| @@ -143,11 +147,11 @@ namespace Tensorflow.Keras.Engine | |||
| StepsPerExecution = _steps_per_execution | |||
| }); | |||
| return FitInternal(data_handler, epochs, verbose, validation_data: validation_data, | |||
| return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data, | |||
| train_step_func: train_step_function); | |||
| } | |||
| History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data, | |||
| History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data, | |||
| Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | |||
| { | |||
| stop_training = false; | |||
| @@ -159,6 +163,13 @@ namespace Tensorflow.Keras.Engine | |||
| Epochs = epochs, | |||
| Steps = data_handler.Inferredsteps | |||
| }); | |||
| if (callbackList != null) | |||
| { | |||
| foreach(var callback in callbackList) | |||
| callbacks.callbacks.add(callback); | |||
| } | |||
| callbacks.on_train_begin(); | |||
| foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
| @@ -144,5 +144,11 @@ namespace Tensorflow.Keras.Engine | |||
| var children = base._trackable_children(save_type, cache); | |||
| return children; | |||
| } | |||
| void IModel.set_stopTraining_true() | |||
| { | |||
| stop_training = true; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,65 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow.Keras.UnitTest.Helpers; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using Tensorflow.Keras.Callbacks; | |||
| using Tensorflow.Keras.Engine; | |||
| using System.Collections.Generic; | |||
| using static Tensorflow.KerasApi; | |||
| using Tensorflow.Keras; | |||
| namespace TensorFlowNET.Keras.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class EarltstoppingTest | |||
| { | |||
| [TestMethod] | |||
| // Because loading the weight variable into the model has not yet been implemented, | |||
| // so you'd better not set patience too large, because the weights will equal to the last epoch's weights. | |||
| public void Earltstopping() | |||
| { | |||
| var layers = keras.layers; | |||
| var model = keras.Sequential(new List<ILayer> | |||
| { | |||
| layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)), | |||
| layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu), | |||
| layers.MaxPooling2D(), | |||
| layers.Flatten(), | |||
| layers.Dense(128, activation: keras.activations.Relu), | |||
| layers.Dense(10) | |||
| }); | |||
| model.summary(); | |||
| model.compile(optimizer: keras.optimizers.RMSprop(1e-3f), | |||
| loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true), | |||
| metrics: new[] { "acc" }); | |||
| var num_epochs = 3; | |||
| var batch_size = 8; | |||
| var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data(); | |||
| x_train = x_train / 255.0f; | |||
| // define a CallbackParams first, the parameters you pass al least contain Model and Epochs. | |||
| CallbackParams callback_parameters = new CallbackParams | |||
| { | |||
| Model = model, | |||
| Epochs = num_epochs, | |||
| }; | |||
| // define your earlystop | |||
| ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy"); | |||
| // define a callbcaklist, then add the earlystopping to it. | |||
| var callbacks = new List<ICallback>(); | |||
| callbacks.add(earlystop); | |||
| model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs,callbacks:callbacks); | |||
| } | |||
| } | |||
| } | |||