Browse Source

fixed issues about displaying progress in console when download/unzip resources

tags/v0.12
Kerry Jiang Oceania2018 6 years ago
parent
commit
9526ddcc3d
8 changed files with 57 additions and 34 deletions
  1. +11
    -10
      src/TensorFlowHub/MnistModelLoader.cs
  2. +39
    -18
      src/TensorFlowHub/Utils.cs
  3. +2
    -1
      test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs
  4. +1
    -1
      test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs
  5. +1
    -1
      test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs
  6. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs
  7. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs
  8. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs

+ 11
- 10
src/TensorFlowHub/MnistModelLoader.cs View File

@@ -15,14 +15,15 @@ namespace Tensorflow.Hub
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";

public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null)
public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false)
{
var loader = new MnistModelLoader();

var setting = new ModelLoadSetting
{
TrainDir = trainDir,
OneHot = oneHot
OneHot = oneHot,
ShowProgressInConsole = showProgressInConsole
};

if (trainSize.HasValue)
@@ -48,37 +49,37 @@ namespace Tensorflow.Hub
sourceUrl = DEFAULT_SOURCE_URL;

// load train images
await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES)
await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir)
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize);

// load train labels
await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS)
await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir)
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize);

// load test images
await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES)
await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir)
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize);

// load test labels
await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS)
await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir)
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize);


+ 39
- 18
src/TensorFlowHub/Utils.cs View File

@@ -19,7 +19,7 @@ namespace Tensorflow.Hub
await modelLoader.DownloadAsync(url, dir, fileName);
}

public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName)
public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName, bool showProgressInConsole = false)
where TDataSet : IDataSet
{
if (!Path.IsPathRooted(dirSaveTo))
@@ -27,19 +27,31 @@ namespace Tensorflow.Hub

var fileSaveTo = Path.Combine(dirSaveTo, fileName);

if (showProgressInConsole)
{
Console.WriteLine($"Downloading {fileName}");
}

if (File.Exists(fileSaveTo))
{
if (showProgressInConsole)
{
Console.WriteLine($"The file {fileName} already exists");
}

return;
}
Directory.CreateDirectory(dirSaveTo);
using (var wc = new WebClient())
{
await wc.DownloadFileTaskAsync(url, fileSaveTo);
await wc.DownloadFileTaskAsync(url, fileSaveTo).ConfigureAwait(false);
}

}

public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo)
public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo, bool showProgressInConsole = false)
where TDataSet : IDataSet
{
if (!Path.IsPathRooted(saveTo))
@@ -50,67 +62,76 @@ namespace Tensorflow.Hub
if (!Path.IsPathRooted(zipFile))
zipFile = Path.Combine(AppContext.BaseDirectory, zipFile);

var destFilePath = Path.Combine(saveTo, Path.GetFileNameWithoutExtension(zipFile));
var destFileName = Path.GetFileNameWithoutExtension(zipFile);
var destFilePath = Path.Combine(saveTo, destFileName);

if (showProgressInConsole)
Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}");

if (File.Exists(destFilePath))
File.Delete(destFilePath);
{
if (showProgressInConsole)
Console.WriteLine($"The file {destFileName} already exists");
}

using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress))
{
using (var destStream = File.Create(destFilePath))
{
await unzipStream.CopyToAsync(destStream);
await destStream.FlushAsync();
await unzipStream.CopyToAsync(destStream).ConfigureAwait(false);
await destStream.FlushAsync().ConfigureAwait(false);
destStream.Close();
}

unzipStream.Close();
}
}

public static async Task ShowProgressInConsole(this Task task)
{
await ShowProgressInConsole(task, true);
}
}

public static async Task ShowProgressInConsole(this Task task, bool enable)
{
if (!enable)
{
await task;
return;
}

var cts = new CancellationTokenSource();

var showProgressTask = ShowProgressInConsole(cts);

try
{
{
await task;
}
finally
{
cts.Cancel();
cts.Cancel();
}

await showProgressTask;
Console.WriteLine("Done.");
}

private static async Task ShowProgressInConsole(CancellationTokenSource cts)
{
var cols = 0;

await Task.Delay(1000);

while (!cts.IsCancellationRequested)
{
await Task.Delay(1000);
Console.Write(".");
cols++;

if (cols >= 50)
if (cols % 50 == 0)
{
cols = 0;
Console.WriteLine();
}
}

Console.WriteLine();
if (cols > 0)
Console.WriteLine();
}
}
}

+ 2
- 1
test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs View File

@@ -70,7 +70,8 @@ namespace TensorFlowNET.Examples
OneHot = true,
TrainSize = train_size,
ValidationSize = validation_size,
TestSize = test_size
TestSize = test_size,
ShowProgressInConsole = true
};

mnist = loader.LoadAsync(setting).Result;


+ 1
- 1
test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs View File

@@ -124,7 +124,7 @@ namespace TensorFlowNET.Examples

public void PrepareData()
{
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size).Result;
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size, showProgressInConsole: true).Result;
}

public void SaveModel(Session sess)


+ 1
- 1
test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs View File

@@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples

public void PrepareData()
{
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize).Result;
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize, showProgressInConsole: true).Result;
// In this example, we limit mnist data
(Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
(Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs View File

@@ -310,7 +310,7 @@ namespace TensorFlowNET.Examples
public void PrepareData()
{
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result;
(x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels);
(x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels);
(x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels);


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs View File

@@ -121,7 +121,7 @@ namespace TensorFlowNET.Examples
public void PrepareData()
{
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result;
}

public void Train(Session sess)


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs View File

@@ -143,7 +143,7 @@ namespace TensorFlowNET.Examples

public void PrepareData()
{
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result;
(x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels);
(x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels);
(x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels);


Loading…
Cancel
Save