Fix validation_split having no output and add a validation_data parameter to model.fit
```diff
@@ -15,5 +15,5 @@ public interface ICallback
     void on_predict_end();
     void on_test_begin();
    void on_test_batch_begin(long step);
-    void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs);
+    void on_test_batch_end(long end_step, Dictionary<string, float> logs);
 }
```
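Since logs now arrive as a `Dictionary<string, float>`, a custom callback reads metric values directly instead of unwrapping tensors. A minimal sketch against the updated interface (the class name is illustrative; the remaining ICallback members would be empty stubs, as in EarlyStopping below):

```csharp
public class TestLossLogger : ICallback
{
    public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
    {
        // Values are plain floats now; no Tensor-to-numpy conversion needed.
        if (logs.TryGetValue("loss", out var loss))
            Console.WriteLine($"step {end_step}: loss = {loss:F6}");
    }

    // ... other ICallback members implemented as empty stubs ...
}
```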
```diff
@@ -22,6 +22,7 @@ public interface IModel : ILayer
     int verbose = 1,
     List<ICallback> callbacks = null,
     float validation_split = 0f,
+    (NDArray val_x, NDArray val_y)? validation_data = null,
     bool shuffle = true,
     int initial_epoch = 0,
     int max_queue_size = 10,
@@ -34,6 +35,7 @@ public interface IModel : ILayer
     int verbose = 1,
     List<ICallback> callbacks = null,
     float validation_split = 0f,
+    (IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
     bool shuffle = true,
     int initial_epoch = 0,
     int max_queue_size = 10,
@@ -65,7 +67,8 @@ public interface IModel : ILayer
     int max_queue_size = 10,
     int workers = 1,
     bool use_multiprocessing = false,
-    bool return_dict = false);
+    bool return_dict = false,
+    bool is_val = false);
     Tensors predict(Tensors x,
     int batch_size = -1,
```
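A usage sketch of the new parameter (the model and arrays are placeholders, not part of this PR):

```csharp
// Pass an explicit hold-out set as a named tuple of NDArrays.
model.fit(x_train, y_train,
    batch_size: 32,
    epochs: 10,
    validation_data: (x_val, y_val));

// Or keep the old behaviour: the trailing 20% of x/y is split off automatically.
model.fit(x_train, y_train, epochs: 10, validation_split: 0.2f);
```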
```diff
@@ -69,7 +69,7 @@ public class CallbackList
     {
         callbacks.ForEach(x => x.on_test_batch_begin(step));
     }
-    public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
+    public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
     {
         callbacks.ForEach(x => x.on_test_batch_end(end_step, logs));
     }
```
```diff
@@ -121,7 +121,7 @@ public class EarlyStopping: ICallback
     public void on_predict_end() { }
     public void on_test_begin() { }
     public void on_test_batch_begin(long step) { }
-    public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { }
+    public void on_test_batch_end(long end_step, Dictionary<string, float> logs) { }
     float get_monitor_value(Dictionary<string, float> logs)
     {
```
```diff
@@ -48,7 +48,7 @@ public class History : ICallback
             {
                 history[log.Key] = new List<float>();
             }
-            history[log.Key].Add((float)log.Value);
+            history[log.Key].Add(log.Value);
         }
     }
@@ -78,7 +78,7 @@ public class History : ICallback
     }
-    public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
+    public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
     {
     }
 }
```
```diff
@@ -105,11 +105,11 @@ namespace Tensorflow.Keras.Callbacks
     {
         _sw.Restart();
     }
-    public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
+    public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
     {
         _sw.Stop();
         var elapse = _sw.ElapsedMilliseconds;
-        var results = string.Join(" - ", logs.Select(x => $"{x.Item1}: {(float)x.Item2.numpy():F6}"));
+        var results = string.Join(" - ", logs.Select(x => $"{x.Key}: {x.Value:F6}"));
         Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}");
         if (!Console.IsOutputRedirected)
```
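With dictionary logs, each key/value pair is formatted directly from the float value. Given the format string above, a batch progress line would look roughly like this (step count, timing, and metric values invented for illustration):

```
0005/0100 - 14ms/step - loss: 0.693147 - accuracy: 0.500000
```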
```diff
@@ -26,6 +26,7 @@ namespace Tensorflow.Keras.Engine
     /// <param name="workers"></param>
     /// <param name="use_multiprocessing"></param>
     /// <param name="return_dict"></param>
+    /// <param name="is_val"></param>
     public Dictionary<string, float> evaluate(NDArray x, NDArray y,
         int batch_size = -1,
         int verbose = 1,
@@ -33,7 +34,9 @@ namespace Tensorflow.Keras.Engine
         int max_queue_size = 10,
         int workers = 1,
         bool use_multiprocessing = false,
-        bool return_dict = false)
+        bool return_dict = false,
+        bool is_val = false
+        )
     {
         if (x.dims[0] != y.dims[0])
         {
```
```diff
@@ -63,11 +66,11 @@ namespace Tensorflow.Keras.Engine
     });
     callbacks.on_test_begin();
-    IEnumerable<(string, Tensor)> logs = null;
+    //Dictionary<string, float>? logs = null;
+    var logs = new Dictionary<string, float>();
     foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
     {
         reset_metrics();
-        callbacks.on_epoch_begin(epoch);
         // data_handler.catch_stop_iteration();
         foreach (var step in data_handler.steps())
@@ -75,19 +78,64 @@ namespace Tensorflow.Keras.Engine
             callbacks.on_test_batch_begin(step);
             logs = test_function(data_handler, iterator);
             var end_step = step + data_handler.StepIncrement;
-            callbacks.on_test_batch_end(end_step, logs);
+            if (is_val == false)
+                callbacks.on_test_batch_end(end_step, logs);
         }
     }
     var results = new Dictionary<string, float>();
     foreach (var log in logs)
     {
-        results[log.Item1] = (float)log.Item2;
+        results[log.Key] = log.Value;
     }
     return results;
 }
-public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1)
+public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int verbose = 1, bool is_val = false)
+{
+    var data_handler = new DataHandler(new DataHandlerArgs
+    {
+        X = new Tensors(x),
+        Y = y,
+        Model = this,
+        StepsPerExecution = _steps_per_execution
+    });
+    var callbacks = new CallbackList(new CallbackParams
+    {
+        Model = this,
+        Verbose = verbose,
+        Steps = data_handler.Inferredsteps
+    });
+    callbacks.on_test_begin();
+    Dictionary<string, float> logs = null;
+    foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
+    {
+        reset_metrics();
+        callbacks.on_epoch_begin(epoch);
+        // data_handler.catch_stop_iteration();
+        foreach (var step in data_handler.steps())
+        {
+            callbacks.on_test_batch_begin(step);
+            logs = test_step_multi_inputs_function(data_handler, iterator);
+            var end_step = step + data_handler.StepIncrement;
+            if (is_val == false)
+                callbacks.on_test_batch_end(end_step, logs);
+        }
+    }
+    var results = new Dictionary<string, float>();
+    foreach (var log in logs)
+    {
+        results[log.Key] = log.Value;
+    }
+    return results;
+}
+public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false)
 {
     var data_handler = new DataHandler(new DataHandlerArgs
     {
@@ -104,7 +152,7 @@ namespace Tensorflow.Keras.Engine
     });
     callbacks.on_test_begin();
-    IEnumerable<(string, Tensor)> logs = null;
+    Dictionary<string, float> logs = null;
     foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
     {
         reset_metrics();
@@ -113,28 +161,38 @@ namespace Tensorflow.Keras.Engine
         foreach (var step in data_handler.steps())
         {
-            // callbacks.on_train_batch_begin(step)
+            callbacks.on_test_batch_begin(step);
             logs = test_function(data_handler, iterator);
+            var end_step = step + data_handler.StepIncrement;
+            if (is_val == false)
+                callbacks.on_test_batch_end(end_step, logs);
         }
     }
     var results = new Dictionary<string, float>();
     foreach (var log in logs)
     {
-        results[log.Item1] = (float)log.Item2;
+        results[log.Key] = log.Value;
     }
     return results;
 }
-IEnumerable<(string, Tensor)> test_function(DataHandler data_handler, OwnedIterator iterator)
+Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
 {
     var data = iterator.next();
     var outputs = test_step(data_handler, data[0], data[1]);
     tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
     return outputs;
 }
-List<(string, Tensor)> test_step(DataHandler data_handler, Tensor x, Tensor y)
+Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
+{
+    var data = iterator.next();
+    var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
+    var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
+    tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
+    return outputs;
+}
+Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y)
 {
     (x, y) = data_handler.DataAdapter.Expand1d(x, y);
     var y_pred = Apply(x, training: false);
@@ -142,7 +200,7 @@ namespace Tensorflow.Keras.Engine
     compiled_metrics.update_state(y, y_pred);
-    return metrics.Select(x => (x.Name, x.result())).ToList();
+    return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
 }
 }
 }
```
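The new is_val flag only suppresses the per-batch console callback; the returned metrics dictionary is unchanged. A usage sketch (the model, arrays, and the "loss" key are placeholders, assuming a compiled loss metric):

```csharp
// Standalone evaluation: per-batch progress lines are printed as before.
var results = model.evaluate(x_test, y_test);

// Internal validation path (as used by FitInternal): batch output suppressed,
// metrics still computed and returned for merging into the epoch logs.
var val_logs = model.evaluate(x_test, y_test, is_val: true);
Console.WriteLine($"val_loss: {val_logs["loss"]:F4}");
```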
```diff
@@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Engine
     /// <param name="callbacks"></param>
     /// <param name="verbose"></param>
     /// <param name="validation_split"></param>
+    /// <param name="validation_data"></param>
     /// <param name="shuffle"></param>
     public ICallback fit(NDArray x, NDArray y,
         int batch_size = -1,
@@ -29,6 +30,7 @@ namespace Tensorflow.Keras.Engine
         int verbose = 1,
         List<ICallback> callbacks = null,
         float validation_split = 0f,
+        (NDArray val_x, NDArray val_y)? validation_data = null,
         bool shuffle = true,
         int initial_epoch = 0,
         int max_queue_size = 10,
@@ -40,11 +42,17 @@ namespace Tensorflow.Keras.Engine
         throw new InvalidArgumentError(
             $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
     }
-    int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
-    var train_x = x[new Slice(0, train_count)];
-    var train_y = y[new Slice(0, train_count)];
-    var val_x = x[new Slice(train_count)];
-    var val_y = y[new Slice(train_count)];
+    var train_x = x;
+    var train_y = y;
+    if (validation_split != 0f && validation_data == null)
+    {
+        int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
+        train_x = x[new Slice(0, train_count)];
+        train_y = y[new Slice(0, train_count)];
+        validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]);
+    }
```
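The split is positional: the leading fraction of the arrays trains and the trailing fraction validates, so callers who need a random split should shuffle beforehand. A standalone sketch of the arithmetic (the sample count is invented):

```csharp
// With 1000 samples and validation_split = 0.2f:
long samples = 1000;
float validation_split = 0.2f;
int train_count = Convert.ToInt32(samples * (1 - validation_split)); // 800
// Rows [0, 800) train; rows [800, 1000) become validation_data.
Console.WriteLine($"train: [0, {train_count}), val: [{train_count}, {samples})");
```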
```diff
     var data_handler = new DataHandler(new DataHandlerArgs
     {
@@ -61,7 +69,7 @@ namespace Tensorflow.Keras.Engine
         StepsPerExecution = _steps_per_execution
     });
-    return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
+    return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
         train_step_func: train_step_function);
 }
@@ -71,6 +79,7 @@ namespace Tensorflow.Keras.Engine
     int verbose = 1,
     List<ICallback> callbacks = null,
     float validation_split = 0f,
+    (IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
     bool shuffle = true,
     int initial_epoch = 0,
     int max_queue_size = 10,
@@ -85,12 +94,19 @@ namespace Tensorflow.Keras.Engine
             $"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}");
         }
     }
-    int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
-    var train_x = x.Select(x => x[new Slice(0, train_count)] as Tensor);
-    var train_y = y[new Slice(0, train_count)];
-    var val_x = x.Select(x => x[new Slice(train_count)] as Tensor);
-    var val_y = y[new Slice(train_count)];
+    var train_x = x;
+    var train_y = y;
+    if (validation_split != 0f && validation_data == null)
+    {
+        int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
+        train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray);
+        train_y = y[new Slice(0, train_count)];
+        var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
+        var val_y = y[new Slice(train_count)];
+        validation_data = (val_x, val_y);
+    }
     var data_handler = new DataHandler(new DataHandlerArgs
     {
@@ -110,29 +126,29 @@ namespace Tensorflow.Keras.Engine
     if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
         data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
     {
-        return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
+        return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
             train_step_func: train_step_multi_inputs_function);
     }
     else
     {
-        return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
+        return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
             train_step_func: train_step_function);
     }
 }
```
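The multi-input overload accepts one NDArray per model input, and validation_data mirrors that shape. A sketch for a two-input model (all names are placeholders, e.g. a BERT-style ids/mask pair):

```csharp
// Hypothetical two-input model: token ids and attention mask.
var inputs = new List<NDArray> { input_ids, attention_mask };
var val_inputs = new List<NDArray> { val_input_ids, val_attention_mask };

model.fit(inputs, labels,
    epochs: 3,
    validation_data: (val_inputs, val_labels));
```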
```diff
 public History fit(IDatasetV2 dataset,
-    IDatasetV2 validation_data = null,
     int batch_size = -1,
     int epochs = 1,
     int verbose = 1,
     List<ICallback> callbacks = null,
+    float validation_split = 0f,
+    IDatasetV2 validation_data = null,
     bool shuffle = true,
     int initial_epoch = 0,
     int max_queue_size = 10,
     int workers = 1,
     bool use_multiprocessing = false)
 {
     var data_handler = new DataHandler(new DataHandlerArgs
     {
         Dataset = dataset,
@@ -147,6 +163,7 @@ namespace Tensorflow.Keras.Engine
         StepsPerExecution = _steps_per_execution
     });
     return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data,
         train_step_func: train_step_function);
 }
```
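Note that validation_data moved after the new validation_split parameter, so positional callers of the dataset overload must update their argument order. A usage sketch (the datasets are placeholders):

```csharp
// train_ds and val_ds are pre-batched IDatasetV2 pipelines.
model.fit(train_ds, validation_data: val_ds, epochs: 5);
```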
```diff
@@ -178,11 +195,13 @@ namespace Tensorflow.Keras.Engine
         callbacks.on_epoch_begin(epoch);
         // data_handler.catch_stop_iteration();
         var logs = new Dictionary<string, float>();
+        long End_step = 0;
         foreach (var step in data_handler.steps())
         {
             callbacks.on_train_batch_begin(step);
             logs = train_step_func(data_handler, iterator);
             var end_step = step + data_handler.StepIncrement;
+            End_step = end_step;
             callbacks.on_train_batch_end(end_step, logs);
         }
@@ -193,6 +212,123 @@ namespace Tensorflow.Keras.Engine
             {
                 logs["val_" + log.Key] = log.Value;
             }
+            callbacks.on_train_batch_end(End_step, logs);
+        }
+        callbacks.on_epoch_end(epoch, logs);
+        GC.Collect();
+        GC.WaitForPendingFinalizers();
+    }
+    return callbacks.History;
+}
+
+History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (NDArray, NDArray)? validation_data,
+    Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
+{
+    stop_training = false;
+    _train_counter.assign(0);
+    var callbacks = new CallbackList(new CallbackParams
+    {
+        Model = this,
+        Verbose = verbose,
+        Epochs = epochs,
+        Steps = data_handler.Inferredsteps
+    });
+    if (callbackList != null)
+    {
+        foreach (var callback in callbackList)
+            callbacks.callbacks.add(callback);
+    }
+    callbacks.on_train_begin();
+    foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
+    {
+        reset_metrics();
+        callbacks.on_epoch_begin(epoch);
+        // data_handler.catch_stop_iteration();
+        var logs = new Dictionary<string, float>();
+        long End_step = 0;
+        foreach (var step in data_handler.steps())
+        {
+            callbacks.on_train_batch_begin(step);
+            logs = train_step_func(data_handler, iterator);
+            var end_step = step + data_handler.StepIncrement;
+            End_step = end_step;
+            callbacks.on_train_batch_end(end_step, logs);
+        }
+        if (validation_data != null)
+        {
+            // Because evaluate calls on_test_batch_end, it would interfere with our progress output,
+            // so we pass is_val: true to suppress on_test_batch_end during validation.
+            var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val: true);
+            foreach (var log in val_logs)
+            {
+                logs["val_" + log.Key] = log.Value;
+            }
+            // After evaluate, logs contains new val_ entries that still need to be printed.
+            callbacks.on_train_batch_end(End_step, logs);
+        }
+        callbacks.on_epoch_end(epoch, logs);
+        GC.Collect();
+        GC.WaitForPendingFinalizers();
+    }
+    return callbacks.History;
+}
+
+History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (IEnumerable<Tensor>, NDArray)? validation_data,
+    Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
+{
+    stop_training = false;
+    _train_counter.assign(0);
+    var callbacks = new CallbackList(new CallbackParams
+    {
+        Model = this,
+        Verbose = verbose,
+        Epochs = epochs,
+        Steps = data_handler.Inferredsteps
+    });
+    if (callbackList != null)
+    {
+        foreach (var callback in callbackList)
+            callbacks.callbacks.add(callback);
+    }
+    callbacks.on_train_begin();
+    foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
+    {
+        reset_metrics();
+        callbacks.on_epoch_begin(epoch);
+        // data_handler.catch_stop_iteration();
+        var logs = new Dictionary<string, float>();
+        long End_step = 0;
+        foreach (var step in data_handler.steps())
+        {
+            callbacks.on_train_batch_begin(step);
+            logs = train_step_func(data_handler, iterator);
+            var end_step = step + data_handler.StepIncrement;
+            End_step = end_step;
+            callbacks.on_train_batch_end(end_step, logs);
+        }
+        if (validation_data != null)
+        {
+            var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2);
+            foreach (var log in val_logs)
+            {
+                logs["val_" + log.Key] = log.Value;
+                callbacks.on_train_batch_end(End_step, logs);
+            }
         }
         callbacks.on_epoch_end(epoch, logs);
```
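End_step holds the index of the last training batch so that, once validation has merged its val_ entries into logs, the same progress line can be re-emitted with the full set of metrics. An epoch line would then read roughly like this (values invented for illustration):

```
0100/0100 - 12ms/step - loss: 0.352100 - accuracy: 0.881000 - val_loss: 0.401200 - val_accuracy: 0.855000
```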