diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs index 950d2c98..42705dee 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs @@ -15,6 +15,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters public IDataAdapter DataAdapter => _adapter; IDatasetV2 _dataset; int _inferred_steps; + public int Inferredsteps => _inferred_steps; int _current_step; int _step_increment; bool _insufficient_data; diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index 3c4960e7..3699f6af 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -1,4 +1,5 @@ using NumSharp; +using ShellProgressBar; using System; using System.Collections.Generic; using System.Linq; @@ -51,22 +52,7 @@ namespace Tensorflow.Keras.Engine StepsPerExecution = _steps_per_execution }); - stop_training = false; - _train_counter.assign(0); - Console.WriteLine($"Training..."); - foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) - { - // reset_metrics(); - // callbacks.on_epoch_begin(epoch) - // data_handler.catch_stop_iteration(); - IEnumerable<(string, Tensor)> results = null; - foreach (var step in data_handler.steps()) - { - // callbacks.on_train_batch_begin(step) - results = step_function(iterator); - } - Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); - } + FitInternal(epochs); } public void fit(IDatasetV2 dataset, @@ -95,21 +81,32 @@ namespace Tensorflow.Keras.Engine StepsPerExecution = _steps_per_execution }); + FitInternal(epochs); + } + + void FitInternal(int epochs) + { stop_training = false; _train_counter.assign(0); - Console.WriteLine($"Training..."); + var options = new ProgressBarOptions + { + ProgressCharacter = '.', + ProgressBarOnBottom = true + }; + foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { + using var pbar = new ProgressBar(data_handler.Inferredsteps, "Training...", options); // reset_metrics(); // callbacks.on_epoch_begin(epoch) // data_handler.catch_stop_iteration(); - IEnumerable<(string, Tensor)> results = null; foreach (var step in data_handler.steps()) { // callbacks.on_train_batch_begin(step) - results = step_function(iterator); + var results = step_function(iterator); + var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); + pbar.Tick($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]"); } - Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); } } } diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index b97fdb8d..e4c42061 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -47,6 +47,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac +