| @@ -15,14 +15,15 @@ namespace Tensorflow.Hub | |||||
| private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | ||||
| private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | ||||
| public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null) | |||||
| public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false) | |||||
| { | { | ||||
| var loader = new MnistModelLoader(); | var loader = new MnistModelLoader(); | ||||
| var setting = new ModelLoadSetting | var setting = new ModelLoadSetting | ||||
| { | { | ||||
| TrainDir = trainDir, | TrainDir = trainDir, | ||||
| OneHot = oneHot | |||||
| OneHot = oneHot, | |||||
| ShowProgressInConsole = showProgressInConsole | |||||
| }; | }; | ||||
| if (trainSize.HasValue) | if (trainSize.HasValue) | ||||
| @@ -48,37 +49,37 @@ namespace Tensorflow.Hub | |||||
| sourceUrl = DEFAULT_SOURCE_URL; | sourceUrl = DEFAULT_SOURCE_URL; | ||||
| // load train images | // load train images | ||||
| await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES) | |||||
| await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) | |||||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | .ShowProgressInConsole(setting.ShowProgressInConsole); | ||||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir) | |||||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | .ShowProgressInConsole(setting.ShowProgressInConsole); | ||||
| var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize); | var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize); | ||||
| // load train labels | // load train labels | ||||
| await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS) | |||||
| await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS, showProgressInConsole: setting.ShowProgressInConsole) | |||||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | .ShowProgressInConsole(setting.ShowProgressInConsole); | ||||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir) | |||||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | .ShowProgressInConsole(setting.ShowProgressInConsole); | ||||
| var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize); | var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize); | ||||
| // load test images | // load test images | ||||
| await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES) | |||||
| await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) | |||||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | .ShowProgressInConsole(setting.ShowProgressInConsole); | ||||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir) | |||||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | .ShowProgressInConsole(setting.ShowProgressInConsole); | ||||
| var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize); | var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize); | ||||
| // load test labels | // load test labels | ||||
| await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS) | |||||
| await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS, showProgressInConsole: setting.ShowProgressInConsole) | |||||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | .ShowProgressInConsole(setting.ShowProgressInConsole); | ||||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir) | |||||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | .ShowProgressInConsole(setting.ShowProgressInConsole); | ||||
| var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize); | var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize); | ||||
| @@ -19,7 +19,7 @@ namespace Tensorflow.Hub | |||||
| await modelLoader.DownloadAsync(url, dir, fileName); | await modelLoader.DownloadAsync(url, dir, fileName); | ||||
| } | } | ||||
| public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName) | |||||
| public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName, bool showProgressInConsole = false) | |||||
| where TDataSet : IDataSet | where TDataSet : IDataSet | ||||
| { | { | ||||
| if (!Path.IsPathRooted(dirSaveTo)) | if (!Path.IsPathRooted(dirSaveTo)) | ||||
| @@ -27,18 +27,30 @@ namespace Tensorflow.Hub | |||||
| var fileSaveTo = Path.Combine(dirSaveTo, fileName); | var fileSaveTo = Path.Combine(dirSaveTo, fileName); | ||||
| if (showProgressInConsole) | |||||
| { | |||||
| Console.WriteLine($"Downloading {fileName}"); | |||||
| } | |||||
| if (File.Exists(fileSaveTo)) | if (File.Exists(fileSaveTo)) | ||||
| { | |||||
| if (showProgressInConsole) | |||||
| { | |||||
| Console.WriteLine($"The file {fileName} already exists"); | |||||
| } | |||||
| return; | return; | ||||
| } | |||||
| Directory.CreateDirectory(dirSaveTo); | Directory.CreateDirectory(dirSaveTo); | ||||
| using (var wc = new WebClient()) | using (var wc = new WebClient()) | ||||
| { | { | ||||
| await wc.DownloadFileTaskAsync(url, fileSaveTo); | |||||
| await wc.DownloadFileTaskAsync(url, fileSaveTo).ConfigureAwait(false); | |||||
| } | } | ||||
| } | } | ||||
| public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo) | |||||
| public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo, bool showProgressInConsole = false) | |||||
| where TDataSet : IDataSet | where TDataSet : IDataSet | ||||
| { | { | ||||
| if (!Path.IsPathRooted(saveTo)) | if (!Path.IsPathRooted(saveTo)) | ||||
| @@ -49,67 +61,76 @@ namespace Tensorflow.Hub | |||||
| if (!Path.IsPathRooted(zipFile)) | if (!Path.IsPathRooted(zipFile)) | ||||
| zipFile = Path.Combine(AppContext.BaseDirectory, zipFile); | zipFile = Path.Combine(AppContext.BaseDirectory, zipFile); | ||||
| var destFilePath = Path.Combine(saveTo, Path.GetFileNameWithoutExtension(zipFile)); | |||||
| var destFileName = Path.GetFileNameWithoutExtension(zipFile); | |||||
| var destFilePath = Path.Combine(saveTo, destFileName); | |||||
| if (showProgressInConsole) | |||||
| Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}"); | |||||
| if (File.Exists(destFilePath)) | if (File.Exists(destFilePath)) | ||||
| File.Delete(destFilePath); | |||||
| { | |||||
| if (showProgressInConsole) | |||||
| Console.WriteLine($"The file {destFileName} already exists"); | |||||
| } | |||||
| using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress)) | using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress)) | ||||
| { | { | ||||
| using (var destStream = File.Create(destFilePath)) | using (var destStream = File.Create(destFilePath)) | ||||
| { | { | ||||
| await unzipStream.CopyToAsync(destStream); | |||||
| await destStream.FlushAsync(); | |||||
| await unzipStream.CopyToAsync(destStream).ConfigureAwait(false); | |||||
| await destStream.FlushAsync().ConfigureAwait(false); | |||||
| destStream.Close(); | destStream.Close(); | ||||
| } | } | ||||
| unzipStream.Close(); | unzipStream.Close(); | ||||
| } | } | ||||
| } | |||||
| public static async Task ShowProgressInConsole(this Task task) | |||||
| { | |||||
| await ShowProgressInConsole(task, true); | |||||
| } | |||||
| } | |||||
| public static async Task ShowProgressInConsole(this Task task, bool enable) | public static async Task ShowProgressInConsole(this Task task, bool enable) | ||||
| { | { | ||||
| if (!enable) | if (!enable) | ||||
| { | { | ||||
| await task; | await task; | ||||
| return; | |||||
| } | } | ||||
| var cts = new CancellationTokenSource(); | var cts = new CancellationTokenSource(); | ||||
| var showProgressTask = ShowProgressInConsole(cts); | var showProgressTask = ShowProgressInConsole(cts); | ||||
| try | try | ||||
| { | |||||
| { | |||||
| await task; | await task; | ||||
| } | } | ||||
| finally | finally | ||||
| { | { | ||||
| cts.Cancel(); | |||||
| cts.Cancel(); | |||||
| } | } | ||||
| await showProgressTask; | |||||
| Console.WriteLine("Done."); | |||||
| } | } | ||||
| private static async Task ShowProgressInConsole(CancellationTokenSource cts) | private static async Task ShowProgressInConsole(CancellationTokenSource cts) | ||||
| { | { | ||||
| var cols = 0; | var cols = 0; | ||||
| await Task.Delay(1000); | |||||
| while (!cts.IsCancellationRequested) | while (!cts.IsCancellationRequested) | ||||
| { | { | ||||
| await Task.Delay(1000); | await Task.Delay(1000); | ||||
| Console.Write("."); | Console.Write("."); | ||||
| cols++; | cols++; | ||||
| if (cols >= 50) | |||||
| if (cols % 50 == 0) | |||||
| { | { | ||||
| cols = 0; | |||||
| Console.WriteLine(); | Console.WriteLine(); | ||||
| } | } | ||||
| } | } | ||||
| Console.WriteLine(); | |||||
| if (cols > 0) | |||||
| Console.WriteLine(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -70,7 +70,8 @@ namespace TensorFlowNET.Examples | |||||
| OneHot = true, | OneHot = true, | ||||
| TrainSize = train_size, | TrainSize = train_size, | ||||
| ValidationSize = validation_size, | ValidationSize = validation_size, | ||||
| TestSize = test_size | |||||
| TestSize = test_size, | |||||
| ShowProgressInConsole = true | |||||
| }; | }; | ||||
| mnist = loader.LoadAsync(setting).Result; | mnist = loader.LoadAsync(setting).Result; | ||||
| @@ -124,7 +124,7 @@ namespace TensorFlowNET.Examples | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size).Result; | |||||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size, showProgressInConsole: true).Result; | |||||
| } | } | ||||
| public void SaveModel(Session sess) | public void SaveModel(Session sess) | ||||
| @@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize).Result; | |||||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize, showProgressInConsole: true).Result; | |||||
| // In this example, we limit mnist data | // In this example, we limit mnist data | ||||
| (Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) | (Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) | ||||
| (Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing | (Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing | ||||
| @@ -310,7 +310,7 @@ namespace TensorFlowNET.Examples | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; | |||||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; | |||||
| (x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels); | (x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels); | ||||
| (x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels); | (x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels); | ||||
| (x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels); | (x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels); | ||||
| @@ -121,7 +121,7 @@ namespace TensorFlowNET.Examples | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; | |||||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; | |||||
| } | } | ||||
| public void Train(Session sess) | public void Train(Session sess) | ||||
| @@ -143,7 +143,7 @@ namespace TensorFlowNET.Examples | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; | |||||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; | |||||
| (x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels); | (x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels); | ||||
| (x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels); | (x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels); | ||||
| (x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels); | (x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels); | ||||