Browse Source

fixed issues about displaying progress in console when download/unzip resources

tags/v0.12
Kerry Jiang 6 years ago
parent
commit
a36d0602e2
8 changed files with 57 additions and 34 deletions
  1. +11
    -10
      src/TensorFlowHub/MnistModelLoader.cs
  2. +39
    -18
      src/TensorFlowHub/Utils.cs
  3. +2
    -1
      test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs
  4. +1
    -1
      test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs
  5. +1
    -1
      test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs
  6. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs
  7. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs
  8. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs

+ 11
- 10
src/TensorFlowHub/MnistModelLoader.cs View File

@@ -15,14 +15,15 @@ namespace Tensorflow.Hub
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";


public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null)
public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false)
{ {
var loader = new MnistModelLoader(); var loader = new MnistModelLoader();


var setting = new ModelLoadSetting var setting = new ModelLoadSetting
{ {
TrainDir = trainDir, TrainDir = trainDir,
OneHot = oneHot
OneHot = oneHot,
ShowProgressInConsole = showProgressInConsole
}; };


if (trainSize.HasValue) if (trainSize.HasValue)
@@ -48,37 +49,37 @@ namespace Tensorflow.Hub
sourceUrl = DEFAULT_SOURCE_URL; sourceUrl = DEFAULT_SOURCE_URL;


// load train images // load train images
await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES)
await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole); .ShowProgressInConsole(setting.ShowProgressInConsole);


await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir)
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole); .ShowProgressInConsole(setting.ShowProgressInConsole);


var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize); var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize);


// load train labels // load train labels
await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS)
await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole); .ShowProgressInConsole(setting.ShowProgressInConsole);


await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir)
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole); .ShowProgressInConsole(setting.ShowProgressInConsole);


var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize); var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize);


// load test images // load test images
await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES)
await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole); .ShowProgressInConsole(setting.ShowProgressInConsole);


await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir)
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole); .ShowProgressInConsole(setting.ShowProgressInConsole);


var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize); var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize);


// load test labels // load test labels
await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS)
await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole); .ShowProgressInConsole(setting.ShowProgressInConsole);


await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir)
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole); .ShowProgressInConsole(setting.ShowProgressInConsole);


var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize); var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize);


+ 39
- 18
src/TensorFlowHub/Utils.cs View File

@@ -19,7 +19,7 @@ namespace Tensorflow.Hub
await modelLoader.DownloadAsync(url, dir, fileName); await modelLoader.DownloadAsync(url, dir, fileName);
} }


public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName)
public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName, bool showProgressInConsole = false)
where TDataSet : IDataSet where TDataSet : IDataSet
{ {
if (!Path.IsPathRooted(dirSaveTo)) if (!Path.IsPathRooted(dirSaveTo))
@@ -27,18 +27,30 @@ namespace Tensorflow.Hub


var fileSaveTo = Path.Combine(dirSaveTo, fileName); var fileSaveTo = Path.Combine(dirSaveTo, fileName);


if (showProgressInConsole)
{
Console.WriteLine($"Downloading {fileName}");
}

if (File.Exists(fileSaveTo)) if (File.Exists(fileSaveTo))
{
if (showProgressInConsole)
{
Console.WriteLine($"The file {fileName} already exists");
}

return; return;
}
Directory.CreateDirectory(dirSaveTo); Directory.CreateDirectory(dirSaveTo);
using (var wc = new WebClient()) using (var wc = new WebClient())
{ {
await wc.DownloadFileTaskAsync(url, fileSaveTo);
await wc.DownloadFileTaskAsync(url, fileSaveTo).ConfigureAwait(false);
} }
} }


public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo)
public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo, bool showProgressInConsole = false)
where TDataSet : IDataSet where TDataSet : IDataSet
{ {
if (!Path.IsPathRooted(saveTo)) if (!Path.IsPathRooted(saveTo))
@@ -49,67 +61,76 @@ namespace Tensorflow.Hub
if (!Path.IsPathRooted(zipFile)) if (!Path.IsPathRooted(zipFile))
zipFile = Path.Combine(AppContext.BaseDirectory, zipFile); zipFile = Path.Combine(AppContext.BaseDirectory, zipFile);


var destFilePath = Path.Combine(saveTo, Path.GetFileNameWithoutExtension(zipFile));
var destFileName = Path.GetFileNameWithoutExtension(zipFile);
var destFilePath = Path.Combine(saveTo, destFileName);

if (showProgressInConsole)
Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}");


if (File.Exists(destFilePath)) if (File.Exists(destFilePath))
File.Delete(destFilePath);
{
if (showProgressInConsole)
Console.WriteLine($"The file {destFileName} already exists");
}


using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress)) using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress))
{ {
using (var destStream = File.Create(destFilePath)) using (var destStream = File.Create(destFilePath))
{ {
await unzipStream.CopyToAsync(destStream);
await destStream.FlushAsync();
await unzipStream.CopyToAsync(destStream).ConfigureAwait(false);
await destStream.FlushAsync().ConfigureAwait(false);
destStream.Close(); destStream.Close();
} }


unzipStream.Close(); unzipStream.Close();
} }
}

public static async Task ShowProgressInConsole(this Task task)
{
await ShowProgressInConsole(task, true);
}
}


public static async Task ShowProgressInConsole(this Task task, bool enable) public static async Task ShowProgressInConsole(this Task task, bool enable)
{ {
if (!enable) if (!enable)
{ {
await task; await task;
return;
} }


var cts = new CancellationTokenSource(); var cts = new CancellationTokenSource();

var showProgressTask = ShowProgressInConsole(cts); var showProgressTask = ShowProgressInConsole(cts);


try try
{
{
await task; await task;
} }
finally finally
{ {
cts.Cancel();
cts.Cancel();
} }

await showProgressTask;
Console.WriteLine("Done.");
} }


private static async Task ShowProgressInConsole(CancellationTokenSource cts) private static async Task ShowProgressInConsole(CancellationTokenSource cts)
{ {
var cols = 0; var cols = 0;


await Task.Delay(1000);

while (!cts.IsCancellationRequested) while (!cts.IsCancellationRequested)
{ {
await Task.Delay(1000); await Task.Delay(1000);
Console.Write("."); Console.Write(".");
cols++; cols++;


if (cols >= 50)
if (cols % 50 == 0)
{ {
cols = 0;
Console.WriteLine(); Console.WriteLine();
} }
} }


Console.WriteLine();
if (cols > 0)
Console.WriteLine();
} }
} }
} }

+ 2
- 1
test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs View File

@@ -70,7 +70,8 @@ namespace TensorFlowNET.Examples
OneHot = true, OneHot = true,
TrainSize = train_size, TrainSize = train_size,
ValidationSize = validation_size, ValidationSize = validation_size,
TestSize = test_size
TestSize = test_size,
ShowProgressInConsole = true
}; };


mnist = loader.LoadAsync(setting).Result; mnist = loader.LoadAsync(setting).Result;


+ 1
- 1
test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs View File

@@ -124,7 +124,7 @@ namespace TensorFlowNET.Examples


public void PrepareData() public void PrepareData()
{ {
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size).Result;
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size, showProgressInConsole: true).Result;
} }


public void SaveModel(Session sess) public void SaveModel(Session sess)


+ 1
- 1
test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs View File

@@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples


public void PrepareData() public void PrepareData()
{ {
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize).Result;
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize, showProgressInConsole: true).Result;
// In this example, we limit mnist data // In this example, we limit mnist data
(Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) (Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
(Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing (Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs View File

@@ -310,7 +310,7 @@ namespace TensorFlowNET.Examples
public void PrepareData() public void PrepareData()
{ {
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result;
(x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels); (x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels);
(x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels); (x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels);
(x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels); (x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels);


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs View File

@@ -121,7 +121,7 @@ namespace TensorFlowNET.Examples
public void PrepareData() public void PrepareData()
{ {
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result;
} }


public void Train(Session sess) public void Train(Session sess)


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs View File

@@ -143,7 +143,7 @@ namespace TensorFlowNET.Examples


public void PrepareData() public void PrepareData()
{ {
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result;
(x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels); (x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels);
(x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels); (x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels);
(x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels); (x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels);


Loading…
Cancel
Save