| @@ -19,8 +19,10 @@ public class EarlyStopping: ICallback | |||||
| string _monitor; | string _monitor; | ||||
| string _mode; | string _mode; | ||||
| bool _restore_best_weights; | bool _restore_best_weights; | ||||
| List<IVariableV1>? _best_weights; | |||||
| List<NDArray>? _best_weights; | |||||
| CallbackParams _parameters; | CallbackParams _parameters; | ||||
| Func<NDArray, NDArray, NDArray> _monitor_op; | |||||
| public Dictionary<string, List<float>>? history { get; set; } | public Dictionary<string, List<float>>? history { get; set; } | ||||
| // user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model | // user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model | ||||
| public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", float min_delta = 0f, int patience = 0, | public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", float min_delta = 0f, int patience = 0, | ||||
| @@ -38,17 +40,49 @@ public class EarlyStopping: ICallback | |||||
| _min_delta = Math.Abs(min_delta); | _min_delta = Math.Abs(min_delta); | ||||
| _restore_best_weights = restore_best_weights; | _restore_best_weights = restore_best_weights; | ||||
| _mode = mode; | _mode = mode; | ||||
| if (mode != "auto" && mode != "min" && mode != "max") | |||||
| if (_mode != "auto" && _mode != "min" && _mode != "max") | |||||
| { | |||||
| Console.WriteLine($"EarlyStopping mode {_mode} is unknown, fallback to auto mode."); | |||||
| _mode = "auto"; | |||||
| } | |||||
| if (_mode == "min") | |||||
| { | |||||
| _monitor_op = np.less; | |||||
| } | |||||
| else if (_mode == "max") | |||||
| { | |||||
| _monitor_op = np.greater; | |||||
| } | |||||
| else | |||||
| { | |||||
| if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc")) | |||||
| { | |||||
| _monitor_op = np.greater; | |||||
| } | |||||
| else | |||||
| { | |||||
| _monitor_op = np.less; | |||||
| } | |||||
| } | |||||
| if (_monitor_op == np.greater) | |||||
| { | { | ||||
| Console.WriteLine("EarlyStopping mode %s is unknown, fallback to auto mode.", mode); | |||||
| _min_delta *= 1; | |||||
| } | |||||
| else | |||||
| { | |||||
| _min_delta *= -1; | |||||
| } | } | ||||
| } | } | ||||
| public void on_train_begin() | public void on_train_begin() | ||||
| { | { | ||||
| _wait = 0; | _wait = 0; | ||||
| _stopped_epoch = 0; | _stopped_epoch = 0; | ||||
| _best = _monitor_op == np.less ? (float)np.Inf : (float)-np.Inf; | |||||
| _best_weights = null; | |||||
| _best_epoch = 0; | _best_epoch = 0; | ||||
| _best = (float)np.Inf; | |||||
| } | } | ||||
| public void on_epoch_begin(int epoch) | public void on_epoch_begin(int epoch) | ||||
| @@ -74,7 +108,7 @@ public class EarlyStopping: ICallback | |||||
| // Restore the weights after first epoch if no progress is ever made. | // Restore the weights after first epoch if no progress is ever made. | ||||
| if (_restore_best_weights && _best_weights == null) | if (_restore_best_weights && _best_weights == null) | ||||
| { | { | ||||
| _best_weights = _parameters.Model.Weights; | |||||
| _best_weights = _parameters.Model.get_weights(); | |||||
| } | } | ||||
| _wait += 1; | _wait += 1; | ||||
| @@ -83,7 +117,7 @@ public class EarlyStopping: ICallback | |||||
| _best = current; | _best = current; | ||||
| _best_epoch = epoch; | _best_epoch = epoch; | ||||
| if (_restore_best_weights) | if (_restore_best_weights) | ||||
| _best_weights = _parameters.Model.TrainableWeights; | |||||
| _best_weights = _parameters.Model.get_weights(); | |||||
| // Only restart wait if we beat both the baseline and our previous best. | // Only restart wait if we beat both the baseline and our previous best. | ||||
| if (_baseline == 0f || _is_improvement(current, _baseline)) | if (_baseline == 0f || _is_improvement(current, _baseline)) | ||||
| _wait = 0; | _wait = 0; | ||||
| @@ -99,7 +133,7 @@ public class EarlyStopping: ICallback | |||||
| { | { | ||||
| Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); | Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); | ||||
| } | } | ||||
| _parameters.Model.Weights = _best_weights; | |||||
| _parameters.Model.set_weights(_best_weights); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -131,21 +165,7 @@ public class EarlyStopping: ICallback | |||||
| } | } | ||||
| public bool _is_improvement(float monitor_value, float reference_value) | public bool _is_improvement(float monitor_value, float reference_value) | ||||
| { | { | ||||
| bool less_op = (monitor_value - _min_delta) < reference_value; | |||||
| bool greater_op = (monitor_value - _min_delta) >= reference_value; | |||||
| if (_mode == "min") | |||||
| return less_op; | |||||
| else if (_mode == "max") | |||||
| return greater_op; | |||||
| else | |||||
| { | |||||
| if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc")) | |||||
| { | |||||
| return greater_op; | |||||
| } | |||||
| else | |||||
| return less_op; | |||||
| } | |||||
| return _monitor_op(monitor_value - _min_delta, reference_value); | |||||
| } | } | ||||
| public void on_test_end(Dictionary<string, float> logs) | public void on_test_end(Dictionary<string, float> logs) | ||||