| @@ -4,6 +4,7 @@ public interface ICallback | |||||
| { | { | ||||
| Dictionary<string, List<float>> history { get; set; } | Dictionary<string, List<float>> history { get; set; } | ||||
| void on_train_begin(); | void on_train_begin(); | ||||
| void on_train_end(); | |||||
| void on_epoch_begin(int epoch); | void on_epoch_begin(int epoch); | ||||
| void on_train_batch_begin(long step); | void on_train_batch_begin(long step); | ||||
| void on_train_batch_end(long end_step, Dictionary<string, float> logs); | void on_train_batch_end(long end_step, Dictionary<string, float> logs); | ||||
| @@ -17,6 +17,7 @@ public interface IModel : ILayer | |||||
| int batch_size = -1, | int batch_size = -1, | ||||
| int epochs = 1, | int epochs = 1, | ||||
| int verbose = 1, | int verbose = 1, | ||||
| List<ICallback> callbacks = null, | |||||
| float validation_split = 0f, | float validation_split = 0f, | ||||
| bool shuffle = true, | bool shuffle = true, | ||||
| int initial_epoch = 0, | int initial_epoch = 0, | ||||
| @@ -28,6 +29,7 @@ public interface IModel : ILayer | |||||
| int batch_size = -1, | int batch_size = -1, | ||||
| int epochs = 1, | int epochs = 1, | ||||
| int verbose = 1, | int verbose = 1, | ||||
| List<ICallback> callbacks = null, | |||||
| float validation_split = 0f, | float validation_split = 0f, | ||||
| bool shuffle = true, | bool shuffle = true, | ||||
| int initial_epoch = 0, | int initial_epoch = 0, | ||||
| @@ -73,4 +75,6 @@ public interface IModel : ILayer | |||||
| void summary(int line_length = -1, float[] positions = null); | void summary(int line_length = -1, float[] positions = null); | ||||
| IKerasConfig get_config(); | IKerasConfig get_config(); | ||||
| void set_stopTraining_true(); | |||||
| } | } | ||||
| @@ -7,7 +7,8 @@ namespace Tensorflow.Keras.Callbacks; | |||||
| public class CallbackList | public class CallbackList | ||||
| { | { | ||||
| List<ICallback> callbacks = new List<ICallback>(); | |||||
| // 改成public使得新定义的callback可以加入到callbacks里 | |||||
| public List<ICallback> callbacks = new List<ICallback>(); | |||||
| public History History => callbacks[0] as History; | public History History => callbacks[0] as History; | ||||
| public CallbackList(CallbackParams parameters) | public CallbackList(CallbackParams parameters) | ||||
| @@ -66,7 +67,7 @@ public class CallbackList | |||||
| public void on_test_batch_begin(long step) | public void on_test_batch_begin(long step) | ||||
| { | { | ||||
| callbacks.ForEach(x => x.on_train_batch_begin(step)); | |||||
| callbacks.ForEach(x => x.on_test_batch_begin(step)); | |||||
| } | } | ||||
| public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | ||||
| { | { | ||||
| @@ -0,0 +1,155 @@ | |||||
| using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow.Keras.Callbacks; | |||||
| /// <summary> | |||||
| /// Stop training when a monitored metric has stopped improving. | |||||
| /// </summary> | |||||
| /// <param name="parameters"></param> | |||||
| /// <param name="monitor"></param> | |||||
| public class EarlyStopping: ICallback | |||||
| { | |||||
| int _paitence; | |||||
| int _min_delta; | |||||
| int _verbose; | |||||
| int _stopped_epoch; | |||||
| int _wait; | |||||
| int _best_epoch; | |||||
| int _start_from_epoch; | |||||
| float _best; | |||||
| float _baseline; | |||||
| string _monitor; | |||||
| string _mode; | |||||
| bool _restore_best_weights; | |||||
| List<IVariableV1>? _best_weights; | |||||
| CallbackParams _parameters; | |||||
| public Dictionary<string, List<float>>? history { get; set; } | |||||
| // user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model | |||||
| public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", int min_delta = 0, int patience = 0, | |||||
| int verbose = 1, string mode = "auto", float baseline = 0f, bool restore_best_weights = false, | |||||
| int start_from_epoch = 0) | |||||
| { | |||||
| _parameters = parameters; | |||||
| _stopped_epoch = 0; | |||||
| _wait = 0; | |||||
| _monitor = monitor; | |||||
| _paitence = patience; | |||||
| _verbose = verbose; | |||||
| _baseline = baseline; | |||||
| _start_from_epoch = start_from_epoch; | |||||
| _min_delta = Math.Abs(min_delta); | |||||
| _restore_best_weights = restore_best_weights; | |||||
| _mode = mode; | |||||
| if (mode != "auto" && mode != "min" && mode != "max") | |||||
| { | |||||
| Console.WriteLine("EarlyStopping mode %s is unknown, fallback to auto mode.", mode); | |||||
| } | |||||
| } | |||||
| public void on_train_begin() | |||||
| { | |||||
| _wait = 0; | |||||
| _stopped_epoch = 0; | |||||
| _best_epoch = 0; | |||||
| _best = (float)np.Inf; | |||||
| } | |||||
| public void on_epoch_begin(int epoch) | |||||
| { | |||||
| } | |||||
| public void on_train_batch_begin(long step) | |||||
| { | |||||
| } | |||||
| public void on_train_batch_end(long end_step, Dictionary<string, float> logs) | |||||
| { | |||||
| } | |||||
| public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs) | |||||
| { | |||||
| var current = get_monitor_value(epoch_logs); | |||||
| // If no monitor value exists or still in initial warm-up stage. | |||||
| if (current == 0f || epoch < _start_from_epoch) | |||||
| return; | |||||
| // Restore the weights after first epoch if no progress is ever made. | |||||
| if (_restore_best_weights && _best_weights == null) | |||||
| { | |||||
| _best_weights = _parameters.Model.TrainableWeights; | |||||
| } | |||||
| _wait += 1; | |||||
| if (_is_improvement(current, _best)) | |||||
| { | |||||
| _best = current; | |||||
| _best_epoch = epoch; | |||||
| if (_restore_best_weights) | |||||
| _best_weights = _parameters.Model.TrainableWeights; | |||||
| // Only restart wait if we beat both the baseline and our previous best. | |||||
| if (_baseline == 0f || _is_improvement(current, _baseline)) | |||||
| _wait = 0; | |||||
| } | |||||
| // Only check after the first epoch. | |||||
| if (_wait >= _paitence && epoch > 0) | |||||
| { | |||||
| _stopped_epoch = epoch; | |||||
| _parameters.Model.set_stopTraining_true(); | |||||
| if (_restore_best_weights && _best_weights != null) | |||||
| { | |||||
| if (_verbose > 0) | |||||
| { | |||||
| Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); | |||||
| } | |||||
| } | |||||
| // Because loading the weight variable into the model has not yet been implemented, so Earlystopping can't load best_weight yet. | |||||
| // TODO(Wanglongzhi2001): implement it. | |||||
| // _parameters.Model.load_weights(best_weights); | |||||
| } | |||||
| } | |||||
| public void on_train_end() | |||||
| { | |||||
| if (_stopped_epoch > 0 && _verbose > 0) | |||||
| { | |||||
| Console.WriteLine($"Epoch {_stopped_epoch + 1}: early stopping"); | |||||
| } | |||||
| } | |||||
| public void on_predict_begin() { } | |||||
| public void on_predict_batch_begin(long step) { } | |||||
| public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) { } | |||||
| public void on_predict_end() { } | |||||
| public void on_test_begin() { } | |||||
| public void on_test_batch_begin(long step) { } | |||||
| public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { } | |||||
| float get_monitor_value(Dictionary<string, float> logs) | |||||
| { | |||||
| logs = logs ?? new Dictionary<string, float>(); | |||||
| float monitor_value = logs[_monitor]; | |||||
| if (monitor_value == 0f) | |||||
| { | |||||
| Console.WriteLine($"Early stopping conditioned on metric {_monitor} " + | |||||
| $"which is not available. Available metrics are: {string.Join(", ", logs.Keys)}"); | |||||
| } | |||||
| return monitor_value; | |||||
| } | |||||
| public bool _is_improvement(float monitor_value, float reference_value) | |||||
| { | |||||
| bool less_op = (monitor_value - _min_delta) < reference_value; | |||||
| bool greater_op = (monitor_value - _min_delta) >= reference_value; | |||||
| if (_mode == "min") | |||||
| return less_op; | |||||
| else if (_mode == "max") | |||||
| return greater_op; | |||||
| else | |||||
| { | |||||
| if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc")) | |||||
| { | |||||
| return greater_op; | |||||
| } | |||||
| else | |||||
| return less_op; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -23,6 +23,7 @@ public class History : ICallback | |||||
| epochs = new List<int>(); | epochs = new List<int>(); | ||||
| history = new Dictionary<string, List<float>>(); | history = new Dictionary<string, List<float>>(); | ||||
| } | } | ||||
| public void on_train_end() { } | |||||
| public void on_epoch_begin(int epoch) | public void on_epoch_begin(int epoch) | ||||
| { | { | ||||
| @@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Callbacks | |||||
| _called_in_fit = true; | _called_in_fit = true; | ||||
| _sw = new Stopwatch(); | _sw = new Stopwatch(); | ||||
| } | } | ||||
| public void on_train_end() { } | |||||
| public void on_test_begin() | public void on_test_begin() | ||||
| { | { | ||||
| _sw = new Stopwatch(); | _sw = new Stopwatch(); | ||||
| @@ -19,6 +19,7 @@ namespace Tensorflow.Keras.Engine | |||||
| /// <param name="y"></param> | /// <param name="y"></param> | ||||
| /// <param name="batch_size"></param> | /// <param name="batch_size"></param> | ||||
| /// <param name="epochs"></param> | /// <param name="epochs"></param> | ||||
| /// <param name="callbacks"></param> | |||||
| /// <param name="verbose"></param> | /// <param name="verbose"></param> | ||||
| /// <param name="validation_split"></param> | /// <param name="validation_split"></param> | ||||
| /// <param name="shuffle"></param> | /// <param name="shuffle"></param> | ||||
| @@ -26,6 +27,7 @@ namespace Tensorflow.Keras.Engine | |||||
| int batch_size = -1, | int batch_size = -1, | ||||
| int epochs = 1, | int epochs = 1, | ||||
| int verbose = 1, | int verbose = 1, | ||||
| List<ICallback> callbacks = null, | |||||
| float validation_split = 0f, | float validation_split = 0f, | ||||
| bool shuffle = true, | bool shuffle = true, | ||||
| int initial_epoch = 0, | int initial_epoch = 0, | ||||
| @@ -59,7 +61,7 @@ namespace Tensorflow.Keras.Engine | |||||
| StepsPerExecution = _steps_per_execution | StepsPerExecution = _steps_per_execution | ||||
| }); | }); | ||||
| return FitInternal(data_handler, epochs, verbose, validation_data: null, | |||||
| return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||||
| train_step_func: train_step_function); | train_step_func: train_step_function); | ||||
| } | } | ||||
| @@ -67,6 +69,7 @@ namespace Tensorflow.Keras.Engine | |||||
| int batch_size = -1, | int batch_size = -1, | ||||
| int epochs = 1, | int epochs = 1, | ||||
| int verbose = 1, | int verbose = 1, | ||||
| List<ICallback> callbacks = null, | |||||
| float validation_split = 0f, | float validation_split = 0f, | ||||
| bool shuffle = true, | bool shuffle = true, | ||||
| int initial_epoch = 0, | int initial_epoch = 0, | ||||
| @@ -107,12 +110,12 @@ namespace Tensorflow.Keras.Engine | |||||
| if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || | if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || | ||||
| data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) | data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) | ||||
| { | { | ||||
| return FitInternal(data_handler, epochs, verbose, validation_data: null, | |||||
| return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||||
| train_step_func: train_step_multi_inputs_function); | train_step_func: train_step_multi_inputs_function); | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| return FitInternal(data_handler, epochs, verbose, validation_data: null, | |||||
| return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||||
| train_step_func: train_step_function); | train_step_func: train_step_function); | ||||
| } | } | ||||
| } | } | ||||
| @@ -122,6 +125,7 @@ namespace Tensorflow.Keras.Engine | |||||
| int batch_size = -1, | int batch_size = -1, | ||||
| int epochs = 1, | int epochs = 1, | ||||
| int verbose = 1, | int verbose = 1, | ||||
| List<ICallback> callbacks = null, | |||||
| float validation_split = 0f, | float validation_split = 0f, | ||||
| bool shuffle = true, | bool shuffle = true, | ||||
| int initial_epoch = 0, | int initial_epoch = 0, | ||||
| @@ -143,11 +147,11 @@ namespace Tensorflow.Keras.Engine | |||||
| StepsPerExecution = _steps_per_execution | StepsPerExecution = _steps_per_execution | ||||
| }); | }); | ||||
| return FitInternal(data_handler, epochs, verbose, validation_data: validation_data, | |||||
| return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data, | |||||
| train_step_func: train_step_function); | train_step_func: train_step_function); | ||||
| } | } | ||||
| History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data, | |||||
| History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data, | |||||
| Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | ||||
| { | { | ||||
| stop_training = false; | stop_training = false; | ||||
| @@ -159,6 +163,13 @@ namespace Tensorflow.Keras.Engine | |||||
| Epochs = epochs, | Epochs = epochs, | ||||
| Steps = data_handler.Inferredsteps | Steps = data_handler.Inferredsteps | ||||
| }); | }); | ||||
| if (callbackList != null) | |||||
| { | |||||
| foreach(var callback in callbackList) | |||||
| callbacks.callbacks.add(callback); | |||||
| } | |||||
| callbacks.on_train_begin(); | callbacks.on_train_begin(); | ||||
| foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | ||||
| @@ -144,5 +144,11 @@ namespace Tensorflow.Keras.Engine | |||||
| var children = base._trackable_children(save_type, cache); | var children = base._trackable_children(save_type, cache); | ||||
| return children; | return children; | ||||
| } | } | ||||
| void IModel.set_stopTraining_true() | |||||
| { | |||||
| stop_training = true; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,65 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Tensorflow.Keras.UnitTest.Helpers; | |||||
| using static Tensorflow.Binding; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Keras.Optimizers; | |||||
| using Tensorflow.Keras.Callbacks; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using System.Collections.Generic; | |||||
| using static Tensorflow.KerasApi; | |||||
| using Tensorflow.Keras; | |||||
| namespace TensorFlowNET.Keras.UnitTest | |||||
| { | |||||
| [TestClass] | |||||
| public class EarltstoppingTest | |||||
| { | |||||
| [TestMethod] | |||||
| // Because loading the weight variable into the model has not yet been implemented, | |||||
| // so you'd better not set patience too large, because the weights will equal to the last epoch's weights. | |||||
| public void Earltstopping() | |||||
| { | |||||
| var layers = keras.layers; | |||||
| var model = keras.Sequential(new List<ILayer> | |||||
| { | |||||
| layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)), | |||||
| layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu), | |||||
| layers.MaxPooling2D(), | |||||
| layers.Flatten(), | |||||
| layers.Dense(128, activation: keras.activations.Relu), | |||||
| layers.Dense(10) | |||||
| }); | |||||
| model.summary(); | |||||
| model.compile(optimizer: keras.optimizers.RMSprop(1e-3f), | |||||
| loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true), | |||||
| metrics: new[] { "acc" }); | |||||
| var num_epochs = 3; | |||||
| var batch_size = 8; | |||||
| var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data(); | |||||
| x_train = x_train / 255.0f; | |||||
| // define a CallbackParams first, the parameters you pass al least contain Model and Epochs. | |||||
| CallbackParams callback_parameters = new CallbackParams | |||||
| { | |||||
| Model = model, | |||||
| Epochs = num_epochs, | |||||
| }; | |||||
| // define your earlystop | |||||
| ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy"); | |||||
| // define a callbcaklist, then add the earlystopping to it. | |||||
| var callbacks = new List<ICallback>(); | |||||
| callbacks.add(earlystop); | |||||
| model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs,callbacks:callbacks); | |||||
| } | |||||
| } | |||||
| } | |||||