| @@ -15,14 +15,15 @@ namespace Tensorflow.Hub | |||
| private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | |||
| private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | |||
| public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null) | |||
| public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false) | |||
| { | |||
| var loader = new MnistModelLoader(); | |||
| var setting = new ModelLoadSetting | |||
| { | |||
| TrainDir = trainDir, | |||
| OneHot = oneHot | |||
| OneHot = oneHot, | |||
| ShowProgressInConsole = showProgressInConsole | |||
| }; | |||
| if (trainSize.HasValue) | |||
| @@ -48,37 +49,37 @@ namespace Tensorflow.Hub | |||
| sourceUrl = DEFAULT_SOURCE_URL; | |||
| // load train images | |||
| await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES) | |||
| await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir) | |||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize); | |||
| // load train labels | |||
| await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS) | |||
| await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir) | |||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize); | |||
| // load test images | |||
| await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES) | |||
| await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir) | |||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize); | |||
| // load test labels | |||
| await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS) | |||
| await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir) | |||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize); | |||
| @@ -19,7 +19,7 @@ namespace Tensorflow.Hub | |||
| await modelLoader.DownloadAsync(url, dir, fileName); | |||
| } | |||
| public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName) | |||
| public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName, bool showProgressInConsole = false) | |||
| where TDataSet : IDataSet | |||
| { | |||
| if (!Path.IsPathRooted(dirSaveTo)) | |||
| @@ -27,18 +27,30 @@ namespace Tensorflow.Hub | |||
| var fileSaveTo = Path.Combine(dirSaveTo, fileName); | |||
| if (showProgressInConsole) | |||
| { | |||
| Console.WriteLine($"Downloading {fileName}"); | |||
| } | |||
| if (File.Exists(fileSaveTo)) | |||
| { | |||
| if (showProgressInConsole) | |||
| { | |||
| Console.WriteLine($"The file {fileName} already exists"); | |||
| } | |||
| return; | |||
| } | |||
| Directory.CreateDirectory(dirSaveTo); | |||
| using (var wc = new WebClient()) | |||
| { | |||
| await wc.DownloadFileTaskAsync(url, fileSaveTo); | |||
| await wc.DownloadFileTaskAsync(url, fileSaveTo).ConfigureAwait(false); | |||
| } | |||
| } | |||
| public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo) | |||
| public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo, bool showProgressInConsole = false) | |||
| where TDataSet : IDataSet | |||
| { | |||
| if (!Path.IsPathRooted(saveTo)) | |||
| @@ -49,67 +61,76 @@ namespace Tensorflow.Hub | |||
| if (!Path.IsPathRooted(zipFile)) | |||
| zipFile = Path.Combine(AppContext.BaseDirectory, zipFile); | |||
| var destFilePath = Path.Combine(saveTo, Path.GetFileNameWithoutExtension(zipFile)); | |||
| var destFileName = Path.GetFileNameWithoutExtension(zipFile); | |||
| var destFilePath = Path.Combine(saveTo, destFileName); | |||
| if (showProgressInConsole) | |||
| Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}"); | |||
| if (File.Exists(destFilePath)) | |||
| File.Delete(destFilePath); | |||
| { | |||
| if (showProgressInConsole) | |||
| Console.WriteLine($"The file {destFileName} already exists"); | |||
| } | |||
| using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress)) | |||
| { | |||
| using (var destStream = File.Create(destFilePath)) | |||
| { | |||
| await unzipStream.CopyToAsync(destStream); | |||
| await destStream.FlushAsync(); | |||
| await unzipStream.CopyToAsync(destStream).ConfigureAwait(false); | |||
| await destStream.FlushAsync().ConfigureAwait(false); | |||
| destStream.Close(); | |||
| } | |||
| unzipStream.Close(); | |||
| } | |||
| } | |||
| public static async Task ShowProgressInConsole(this Task task) | |||
| { | |||
| await ShowProgressInConsole(task, true); | |||
| } | |||
| } | |||
| public static async Task ShowProgressInConsole(this Task task, bool enable) | |||
| { | |||
| if (!enable) | |||
| { | |||
| await task; | |||
| return; | |||
| } | |||
| var cts = new CancellationTokenSource(); | |||
| var showProgressTask = ShowProgressInConsole(cts); | |||
| try | |||
| { | |||
| { | |||
| await task; | |||
| } | |||
| finally | |||
| { | |||
| cts.Cancel(); | |||
| cts.Cancel(); | |||
| } | |||
| await showProgressTask; | |||
| Console.WriteLine("Done."); | |||
| } | |||
| private static async Task ShowProgressInConsole(CancellationTokenSource cts) | |||
| { | |||
| var cols = 0; | |||
| await Task.Delay(1000); | |||
| while (!cts.IsCancellationRequested) | |||
| { | |||
| await Task.Delay(1000); | |||
| Console.Write("."); | |||
| cols++; | |||
| if (cols >= 50) | |||
| if (cols % 50 == 0) | |||
| { | |||
| cols = 0; | |||
| Console.WriteLine(); | |||
| } | |||
| } | |||
| Console.WriteLine(); | |||
| if (cols > 0) | |||
| Console.WriteLine(); | |||
| } | |||
| } | |||
| } | |||
| @@ -70,7 +70,8 @@ namespace TensorFlowNET.Examples | |||
| OneHot = true, | |||
| TrainSize = train_size, | |||
| ValidationSize = validation_size, | |||
| TestSize = test_size | |||
| TestSize = test_size, | |||
| ShowProgressInConsole = true | |||
| }; | |||
| mnist = loader.LoadAsync(setting).Result; | |||
| @@ -124,7 +124,7 @@ namespace TensorFlowNET.Examples | |||
| public void PrepareData() | |||
| { | |||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size).Result; | |||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size, showProgressInConsole: true).Result; | |||
| } | |||
| public void SaveModel(Session sess) | |||
| @@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples | |||
| public void PrepareData() | |||
| { | |||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize).Result; | |||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize, showProgressInConsole: true).Result; | |||
| // In this example, we limit mnist data | |||
| (Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) | |||
| (Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing | |||
| @@ -310,7 +310,7 @@ namespace TensorFlowNET.Examples | |||
| public void PrepareData() | |||
| { | |||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; | |||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; | |||
| (x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels); | |||
| (x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels); | |||
| (x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels); | |||
| @@ -121,7 +121,7 @@ namespace TensorFlowNET.Examples | |||
| public void PrepareData() | |||
| { | |||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; | |||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; | |||
| } | |||
| public void Train(Session sess) | |||
| @@ -143,7 +143,7 @@ namespace TensorFlowNET.Examples | |||
| public void PrepareData() | |||
| { | |||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; | |||
| mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; | |||
| (x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels); | |||
| (x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels); | |||
| (x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels); | |||