| @@ -19,8 +19,10 @@ public class EarlyStopping: ICallback | |||
| string _monitor; | |||
| string _mode; | |||
| bool _restore_best_weights; | |||
| List<IVariableV1>? _best_weights; | |||
| List<NDArray>? _best_weights; | |||
| CallbackParams _parameters; | |||
| Func<NDArray, NDArray, NDArray> _monitor_op; | |||
| public Dictionary<string, List<float>>? history { get; set; } | |||
| // user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model | |||
| public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", float min_delta = 0f, int patience = 0, | |||
| @@ -38,17 +40,49 @@ public class EarlyStopping: ICallback | |||
| _min_delta = Math.Abs(min_delta); | |||
| _restore_best_weights = restore_best_weights; | |||
| _mode = mode; | |||
| if (mode != "auto" && mode != "min" && mode != "max") | |||
| if (_mode != "auto" && _mode != "min" && _mode != "max") | |||
| { | |||
| Console.WriteLine($"EarlyStopping mode {_mode} is unknown, fallback to auto mode."); | |||
| _mode = "auto"; | |||
| } | |||
| if (_mode == "min") | |||
| { | |||
| _monitor_op = np.less; | |||
| } | |||
| else if (_mode == "max") | |||
| { | |||
| _monitor_op = np.greater; | |||
| } | |||
| else | |||
| { | |||
| if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc")) | |||
| { | |||
| _monitor_op = np.greater; | |||
| } | |||
| else | |||
| { | |||
| _monitor_op = np.less; | |||
| } | |||
| } | |||
| if (_monitor_op == np.greater) | |||
| { | |||
| Console.WriteLine("EarlyStopping mode %s is unknown, fallback to auto mode.", mode); | |||
| _min_delta *= 1; | |||
| } | |||
| else | |||
| { | |||
| _min_delta *= -1; | |||
| } | |||
| } | |||
| public void on_train_begin() | |||
| { | |||
| _wait = 0; | |||
| _stopped_epoch = 0; | |||
| _best = _monitor_op == np.less ? (float)np.Inf : (float)-np.Inf; | |||
| _best_weights = null; | |||
| _best_epoch = 0; | |||
| _best = (float)np.Inf; | |||
| } | |||
| public void on_epoch_begin(int epoch) | |||
| @@ -74,7 +108,7 @@ public class EarlyStopping: ICallback | |||
| // Restore the weights after first epoch if no progress is ever made. | |||
| if (_restore_best_weights && _best_weights == null) | |||
| { | |||
| _best_weights = _parameters.Model.Weights; | |||
| _best_weights = _parameters.Model.get_weights(); | |||
| } | |||
| _wait += 1; | |||
| @@ -83,7 +117,7 @@ public class EarlyStopping: ICallback | |||
| _best = current; | |||
| _best_epoch = epoch; | |||
| if (_restore_best_weights) | |||
| _best_weights = _parameters.Model.TrainableWeights; | |||
| _best_weights = _parameters.Model.get_weights(); | |||
| // Only restart wait if we beat both the baseline and our previous best. | |||
| if (_baseline == 0f || _is_improvement(current, _baseline)) | |||
| _wait = 0; | |||
| @@ -99,7 +133,7 @@ public class EarlyStopping: ICallback | |||
| { | |||
| Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); | |||
| } | |||
| _parameters.Model.Weights = _best_weights; | |||
| _parameters.Model.set_weights(_best_weights); | |||
| } | |||
| } | |||
| } | |||
| @@ -131,21 +165,7 @@ public class EarlyStopping: ICallback | |||
| } | |||
| public bool _is_improvement(float monitor_value, float reference_value) | |||
| { | |||
| bool less_op = (monitor_value - _min_delta) < reference_value; | |||
| bool greater_op = (monitor_value - _min_delta) >= reference_value; | |||
| if (_mode == "min") | |||
| return less_op; | |||
| else if (_mode == "max") | |||
| return greater_op; | |||
| else | |||
| { | |||
| if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc")) | |||
| { | |||
| return greater_op; | |||
| } | |||
| else | |||
| return less_op; | |||
| } | |||
| return _monitor_op(monitor_value - _min_delta, reference_value); | |||
| } | |||
| public void on_test_end(Dictionary<string, float> logs) | |||