diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
index 950d2c98..42705dee 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
@@ -15,6 +15,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public IDataAdapter DataAdapter => _adapter;
IDatasetV2 _dataset;
int _inferred_steps;
+ public int Inferredsteps => _inferred_steps;
int _current_step;
int _step_increment;
bool _insufficient_data;
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
index 3c4960e7..3699f6af 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
@@ -1,4 +1,5 @@
using NumSharp;
+using ShellProgressBar;
using System;
using System.Collections.Generic;
using System.Linq;
@@ -51,22 +52,7 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});
- stop_training = false;
- _train_counter.assign(0);
- Console.WriteLine($"Training...");
- foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
- {
- // reset_metrics();
- // callbacks.on_epoch_begin(epoch)
- // data_handler.catch_stop_iteration();
- IEnumerable<(string, Tensor)> results = null;
- foreach (var step in data_handler.steps())
- {
- // callbacks.on_train_batch_begin(step)
- results = step_function(iterator);
- }
- Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
- }
+ FitInternal(epochs);
}
public void fit(IDatasetV2 dataset,
@@ -95,21 +81,32 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});
+ FitInternal(epochs);
+ }
+
+ void FitInternal(int epochs)
+ {
stop_training = false;
_train_counter.assign(0);
- Console.WriteLine($"Training...");
+ var options = new ProgressBarOptions
+ {
+ ProgressCharacter = '.',
+ ProgressBarOnBottom = true
+ };
+
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
+ using var pbar = new ProgressBar(data_handler.Inferredsteps, "Training...", options);
// reset_metrics();
// callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration();
- IEnumerable<(string, Tensor)> results = null;
foreach (var step in data_handler.steps())
{
// callbacks.on_train_batch_begin(step)
- results = step_function(iterator);
+ var results = step_function(iterator);
+ var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
+ pbar.Tick($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]");
}
- Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
}
}
}
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index b97fdb8d..e4c42061 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -47,6 +47,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
+