Browse Source

Add ShellProgressBar for model fitting.

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
cdc0c2ecac
3 changed files with 19 additions and 20 deletions
  1. +1
    -0
      src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
  2. +17
    -20
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  3. +1
    -0
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

+ 1
- 0
src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs View File

@@ -15,6 +15,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public IDataAdapter DataAdapter => _adapter;
IDatasetV2 _dataset;
int _inferred_steps;
public int Inferredsteps => _inferred_steps;
int _current_step;
int _step_increment;
bool _insufficient_data;


+ 17
- 20
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -1,4 +1,5 @@
using NumSharp;
using ShellProgressBar;
using System;
using System.Collections.Generic;
using System.Linq;
@@ -51,22 +52,7 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});

stop_training = false;
_train_counter.assign(0);
Console.WriteLine($"Training...");
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
// reset_metrics();
// callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration();
IEnumerable<(string, Tensor)> results = null;
foreach (var step in data_handler.steps())
{
// callbacks.on_train_batch_begin(step)
results = step_function(iterator);
}
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
}
FitInternal(epochs);
}

public void fit(IDatasetV2 dataset,
@@ -95,21 +81,32 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});

FitInternal(epochs);
}

void FitInternal(int epochs)
{
stop_training = false;
_train_counter.assign(0);
Console.WriteLine($"Training...");
var options = new ProgressBarOptions
{
ProgressCharacter = '.',
ProgressBarOnBottom = true
};

foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
using var pbar = new ProgressBar(data_handler.Inferredsteps, "Training...", options);
// reset_metrics();
// callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration();
IEnumerable<(string, Tensor)> results = null;
foreach (var step in data_handler.steps())
{
// callbacks.on_train_batch_begin(step)
results = step_function(iterator);
var results = step_function(iterator);
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
pbar.Tick($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]");
}
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
}
}
}


+ 1
- 0
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -47,6 +47,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
<PackageReference Include="SharpZipLib" Version="1.3.1" />
<PackageReference Include="ShellProgressBar" Version="5.0.0" />
</ItemGroup>

<ItemGroup>


Loading…
Cancel
Save