From a36d0602e2b21c0e860d9a989238f95047c987ef Mon Sep 17 00:00:00 2001 From: Kerry Jiang Date: Wed, 7 Aug 2019 00:53:07 -0700 Subject: [PATCH] fixed issues about displaying progress in console when download/unzip resources --- src/TensorFlowHub/MnistModelLoader.cs | 21 +++---- src/TensorFlowHub/Utils.cs | 57 +++++++++++++------ .../BasicModels/KMeansClustering.cs | 3 +- .../BasicModels/LogisticRegression.cs | 2 +- .../BasicModels/NearestNeighbor.cs | 2 +- .../ImageProcessing/DigitRecognitionCNN.cs | 2 +- .../ImageProcessing/DigitRecognitionNN.cs | 2 +- .../ImageProcessing/DigitRecognitionRNN.cs | 2 +- 8 files changed, 57 insertions(+), 34 deletions(-) diff --git a/src/TensorFlowHub/MnistModelLoader.cs b/src/TensorFlowHub/MnistModelLoader.cs index 121c0961..3a9fabb2 100644 --- a/src/TensorFlowHub/MnistModelLoader.cs +++ b/src/TensorFlowHub/MnistModelLoader.cs @@ -15,14 +15,15 @@ namespace Tensorflow.Hub private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; - public static async Task> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null) + public static async Task> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false) { var loader = new MnistModelLoader(); var setting = new ModelLoadSetting { TrainDir = trainDir, - OneHot = oneHot + OneHot = oneHot, + ShowProgressInConsole = showProgressInConsole }; if (trainSize.HasValue) @@ -48,37 +49,37 @@ namespace Tensorflow.Hub sourceUrl = DEFAULT_SOURCE_URL; // load train images - await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES) + await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) .ShowProgressInConsole(setting.ShowProgressInConsole); - await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir) + await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) .ShowProgressInConsole(setting.ShowProgressInConsole); var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize); // load train labels - await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS) + await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS, showProgressInConsole: setting.ShowProgressInConsole) .ShowProgressInConsole(setting.ShowProgressInConsole); - await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir) + await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) .ShowProgressInConsole(setting.ShowProgressInConsole); var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize); // load test images - await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES) + await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) .ShowProgressInConsole(setting.ShowProgressInConsole); - await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir) + await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) .ShowProgressInConsole(setting.ShowProgressInConsole); var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize); // load test labels - await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS) + await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS, showProgressInConsole: setting.ShowProgressInConsole) .ShowProgressInConsole(setting.ShowProgressInConsole); - await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir) + await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) .ShowProgressInConsole(setting.ShowProgressInConsole); var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize); diff --git a/src/TensorFlowHub/Utils.cs b/src/TensorFlowHub/Utils.cs index 3245071f..46f94f35 100644 --- a/src/TensorFlowHub/Utils.cs +++ b/src/TensorFlowHub/Utils.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Hub await modelLoader.DownloadAsync(url, dir, fileName); } - public static async Task DownloadAsync(this IModelLoader modelLoader, string url, string dirSaveTo, string fileName) + public static async Task DownloadAsync(this IModelLoader modelLoader, string url, string dirSaveTo, string fileName, bool showProgressInConsole = false) where TDataSet : IDataSet { if (!Path.IsPathRooted(dirSaveTo)) @@ -27,18 +27,30 @@ namespace Tensorflow.Hub var fileSaveTo = Path.Combine(dirSaveTo, fileName); + if (showProgressInConsole) + { + Console.WriteLine($"Downloading {fileName}"); + } + if (File.Exists(fileSaveTo)) + { + if (showProgressInConsole) + { + Console.WriteLine($"The file {fileName} already exists"); + } + return; + } Directory.CreateDirectory(dirSaveTo); using (var wc = new WebClient()) { - await wc.DownloadFileTaskAsync(url, fileSaveTo); + await wc.DownloadFileTaskAsync(url, fileSaveTo).ConfigureAwait(false); } } - public static async Task UnzipAsync(this IModelLoader modelLoader, string zipFile, string saveTo) + public static async Task UnzipAsync(this IModelLoader modelLoader, string zipFile, string saveTo, bool showProgressInConsole = false) where TDataSet : IDataSet { if (!Path.IsPathRooted(saveTo)) @@ -49,67 +61,76 @@ namespace Tensorflow.Hub if (!Path.IsPathRooted(zipFile)) zipFile = Path.Combine(AppContext.BaseDirectory, zipFile); - var destFilePath = Path.Combine(saveTo, Path.GetFileNameWithoutExtension(zipFile)); + var destFileName = Path.GetFileNameWithoutExtension(zipFile); + var destFilePath = Path.Combine(saveTo, destFileName); + + if (showProgressInConsole) + Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}"); if (File.Exists(destFilePath)) - File.Delete(destFilePath); + { + if (showProgressInConsole) + Console.WriteLine($"The file {destFileName} already exists"); + } using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress)) { using (var destStream = File.Create(destFilePath)) { - await unzipStream.CopyToAsync(destStream); - await destStream.FlushAsync(); + await unzipStream.CopyToAsync(destStream).ConfigureAwait(false); + await destStream.FlushAsync().ConfigureAwait(false); destStream.Close(); } unzipStream.Close(); } - } - - public static async Task ShowProgressInConsole(this Task task) - { - await ShowProgressInConsole(task, true); - } + } public static async Task ShowProgressInConsole(this Task task, bool enable) { if (!enable) { await task; + return; } var cts = new CancellationTokenSource(); + var showProgressTask = ShowProgressInConsole(cts); try - { + { await task; } finally { - cts.Cancel(); + cts.Cancel(); } + + await showProgressTask; + Console.WriteLine("Done."); } private static async Task ShowProgressInConsole(CancellationTokenSource cts) { var cols = 0; + await Task.Delay(1000); + while (!cts.IsCancellationRequested) { await Task.Delay(1000); Console.Write("."); cols++; - if (cols >= 50) + if (cols % 50 == 0) { - cols = 0; Console.WriteLine(); } } - Console.WriteLine(); + if (cols > 0) + Console.WriteLine(); } } } diff --git a/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs b/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs index 7bacb28d..9221d68c 100644 --- a/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs +++ b/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs @@ -70,7 +70,8 @@ namespace TensorFlowNET.Examples OneHot = true, TrainSize = train_size, ValidationSize = validation_size, - TestSize = test_size + TestSize = test_size, + ShowProgressInConsole = true }; mnist = loader.LoadAsync(setting).Result; diff --git a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs index ca691d40..263023ef 100644 --- a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs @@ -124,7 +124,7 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size).Result; + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size, showProgressInConsole: true).Result; } public void SaveModel(Session sess) diff --git a/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs b/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs index 8f761d00..7ae34364 100644 --- a/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs +++ b/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs @@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize).Result; + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize, showProgressInConsole: true).Result; // In this example, we limit mnist data (Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) (Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs index 4b882a1a..dd2cc756 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs @@ -310,7 +310,7 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; (x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels); (x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels); (x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs index 02feecb9..49bbc680 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs @@ -121,7 +121,7 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; } public void Train(Session sess) diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs index b91a19ca..07df8e6a 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs @@ -143,7 +143,7 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; (x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels); (x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels); (x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels);