using System;
using System.Collections.Generic;
using Tensorflow.Keras.Engine;
namespace Tensorflow.Keras.Callbacks;
/// <summary>
/// Stop training when a monitored metric has stopped improving.
/// </summary>
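/// <remarks>
/// Training is stopped once the monitored value has failed to improve by at least
/// <c>min_delta</c> for <c>patience</c> consecutive epochs (epochs before
/// <c>start_from_epoch</c> are ignored); optionally the weights from the best epoch are restored.
/// </remarks>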
public class EarlyStopping: ICallback
{
int _patience;
float _min_delta;
int _verbose;
int _stopped_epoch;
int _wait;
int _best_epoch;
int _start_from_epoch;
float _best;
float _baseline;
string _monitor;
string _mode;
bool _restore_best_weights;
List<IVariableV1>? _best_weights;
CallbackParams _parameters;
public Dictionary<string, List<float>>? history { get; set; }
// The user needs to pass a CallbackParams to EarlyStopping; the CallbackParams must at least contain the model.
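// A minimal usage sketch. It assumes CallbackParams exposes Model and Epochs members and that
// model.fit accepts a callbacks argument; the surrounding names (model, x_train, y_train) are
// illustrative and may differ between versions:
//
//     var cb_params = new CallbackParams { Model = model, Epochs = 10 };
//     var early_stop = new EarlyStopping(cb_params, monitor: "val_loss", patience: 3, restore_best_weights: true);
//     model.fit(x_train, y_train, epochs: 10, validation_split: 0.2f,
//               callbacks: new List<ICallback> { early_stop });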
public EarlyStopping(CallbackParams parameters, string monitor = "val_loss", float min_delta = 0f, int patience = 0,
int verbose = 1, string mode = "auto", float baseline = 0f, bool restore_best_weights = false,
int start_from_epoch = 0)
{
_parameters = parameters;
_stopped_epoch = 0;
_wait = 0;
_monitor = monitor;
_patience = patience;
_verbose = verbose;
_baseline = baseline;
_start_from_epoch = start_from_epoch;
_min_delta = Math.Abs(min_delta);
_restore_best_weights = restore_best_weights;
_mode = mode;
if (mode != "auto" && mode != "min" && mode != "max")
{
Console.WriteLine("EarlyStopping mode %s is unknown, fallback to auto mode.", mode);
}
}
public void on_train_begin()
{
_wait = 0;
_stopped_epoch = 0;
_best_epoch = 0;
// Start from the worst possible value so the first monitored value always counts as an improvement.
bool maximize = _mode == "max" || (_mode != "min" && (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc")));
_best = maximize ? float.NegativeInfinity : float.PositiveInfinity;
}
public void on_epoch_begin(int epoch)
{
}
public void on_train_batch_begin(long step)
{
}
public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
{
}
public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
{
var current = get_monitor_value(epoch_logs);
// If no monitor value exists or still in initial warm-up stage.
if (current == 0f || epoch < _start_from_epoch)
return;
// Restore the weights after first epoch if no progress is ever made.
if (_restore_best_weights && _best_weights == null)
{
_best_weights = _parameters.Model.Weights;
}
_wait += 1;
if (_is_improvement(current, _best))
{
_best = current;
_best_epoch = epoch;
if (_restore_best_weights)
_best_weights = _parameters.Model.Weights;
// Only restart wait if we beat both the baseline and our previous best.
if (_baseline == 0f || _is_improvement(current, _baseline))
_wait = 0;
}
// Only check after the first epoch.
if (_wait >= _patience && epoch > 0)
{
_stopped_epoch = epoch;
_parameters.Model.Stop_training = true;
if (_restore_best_weights && _best_weights != null)
{
if (_verbose > 0)
{
Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
}
_parameters.Model.Weights = _best_weights;
}
}
}
public void on_train_end()
{
if (_stopped_epoch > 0 && _verbose > 0)
{
Console.WriteLine($"Epoch {_stopped_epoch + 1}: early stopping");
}
}
public void on_predict_begin() { }
public void on_predict_batch_begin(long step) { }
public void on_predict_batch_end(long end_step, Dictionary<string, float> logs) { }
public void on_predict_end() { }
public void on_test_begin() { }
public void on_test_batch_begin(long step) { }
public void on_test_batch_end(long end_step, Dictionary<string, float> logs) { }
float get_monitor_value(Dictionary<string, float> logs)
{
logs = logs ?? new Dictionary<string, float>();
// TryGetValue leaves monitor_value at its default of 0f when the metric is missing,
// which on_epoch_end treats as "no monitor value".
if (!logs.TryGetValue(_monitor, out float monitor_value))
{
Console.WriteLine($"Early stopping conditioned on metric {_monitor} " +
$"which is not available. Available metrics are: {string.Join(", ", logs.Keys)}");
}
return monitor_value;
}
public bool _is_improvement(float monitor_value, float reference_value)
{
bool less_op = (monitor_value - _min_delta) < reference_value;
bool greater_op = (monitor_value - _min_delta) >= reference_value;
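// Worked example: with min_delta = 0.01 in "min" mode, a current val_loss of 0.48 against
// a best of 0.50 gives 0.48 - 0.01 = 0.47 < 0.50, so the epoch counts as an improvement.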
if (_mode == "min")
return less_op;
else if (_mode == "max")
return greater_op;
else
{
if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc"))
{
return greater_op;
}
else
return less_op;
}
}
public void on_test_end(Dictionary<string, float> logs)
{
}
}