You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MnistModelLoader.cs 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. using System;
  2. using System.Threading.Tasks;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using System.IO;
  6. using NumSharp;
  7. namespace Tensorflow.Hub
  8. {
  9. public class MnistModelLoader : IModelLoader<MnistDataSet>
  10. {
  // Base URL used by LoadAsync when the setting's SourceUrl is null or empty.
  11. private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
  // Archive file names appended to the source URL; after unzipping, the loader
  // reads the file with the same name minus its final extension.
  12. private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
  13. private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
  14. private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
  15. private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";
  /// <summary>
  /// Convenience entry point: packs the arguments into a <see cref="ModelLoadSetting"/>
  /// and loads the data sets through a new <see cref="MnistModelLoader"/>.
  /// </summary>
  /// <param name="trainDir">Stored as the setting's <c>TrainDir</c>.</param>
  /// <param name="oneHot">Stored as the setting's <c>OneHot</c>.</param>
  /// <param name="trainSize">When supplied, overrides the setting's <c>TrainSize</c>.</param>
  /// <param name="validationSize">When supplied, overrides the setting's <c>ValidationSize</c>.</param>
  /// <param name="testSize">When supplied, overrides the setting's <c>TestSize</c>.</param>
  /// <returns>The data sets produced by the instance <c>LoadAsync</c> overload.</returns>
  public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null)
  {
      var options = new ModelLoadSetting
      {
          TrainDir = trainDir,
          OneHot = oneHot
      };

      // Sizes the caller left out keep whatever value ModelLoadSetting starts with.
      if (trainSize != null)
      {
          options.TrainSize = trainSize.Value;
      }

      if (validationSize != null)
      {
          options.ValidationSize = validationSize.Value;
      }

      if (testSize != null)
      {
          options.TestSize = testSize.Value;
      }

      var mnistLoader = new MnistModelLoader();
      var datasets = await mnistLoader.LoadAsync(options);
      return datasets;
  }
  /// <summary>
  /// Downloads and unzips the four MNIST archives into <c>setting.TrainDir</c>,
  /// parses them, and splits the leading <c>setting.ValidationSize</c> training
  /// samples off as the validation set.
  /// </summary>
  /// <param name="setting">
  /// Load options. <c>SourceUrl</c> falls back to the built-in default URL when
  /// null or empty; <c>TrainSize</c> / <c>TestSize</c> cap how many samples are read.
  /// </param>
  /// <returns>Train, validation and test data sets, in that order.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="setting"/> is null.</exception>
  /// <exception cref="ArgumentException">
  /// The validation size is not smaller than the requested train size, is negative,
  /// or exceeds the number of training samples actually loaded.
  /// </exception>
  public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
  {
      if (setting == null)
          throw new ArgumentNullException(nameof(setting));

      if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value)
          throw new ArgumentException("Validation set should be smaller than training set");

      var sourceUrl = setting.SourceUrl;
      if (string.IsNullOrEmpty(sourceUrl))
          sourceUrl = DEFAULT_SOURCE_URL;

      // Same order as before: each archive is fetched, unzipped and parsed
      // before the next one is touched.
      var trainImagesPath = await FetchAndUnzipAsync(sourceUrl, TRAIN_IMAGES, setting);
      var trainImages = ExtractImages(trainImagesPath, limit: setting.TrainSize);

      var trainLabelsPath = await FetchAndUnzipAsync(sourceUrl, TRAIN_LABELS, setting);
      var trainLabels = ExtractLabels(trainLabelsPath, one_hot: setting.OneHot, limit: setting.TrainSize);

      var testImagesPath = await FetchAndUnzipAsync(sourceUrl, TEST_IMAGES, setting);
      var testImages = ExtractImages(testImagesPath, limit: setting.TestSize);

      var testLabelsPath = await FetchAndUnzipAsync(sourceUrl, TEST_LABELS, setting);
      var testLabels = ExtractLabels(testLabelsPath, one_hot: setting.OneHot, limit: setting.TestSize);

      var end = trainImages.shape[0];
      var validationSize = setting.ValidationSize;

      // The early check only compares against the *requested* train size; the
      // split below needs 0 <= validationSize <= number of samples really loaded.
      if (validationSize < 0 || validationSize > end)
          throw new ArgumentException($"Validation size {validationSize} must be between 0 and the number of loaded training samples ({end})");

      // The first validationSize samples become the validation set, the rest stay in training.
      var validationImages = trainImages[np.arange(validationSize)];
      var validationLabels = trainLabels[np.arange(validationSize)];
      trainImages = trainImages[np.arange(validationSize, end)];
      trainLabels = trainLabels[np.arange(validationSize, end)];

      var dtype = setting.DataType;
      var reshape = setting.ReShape;

      var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape);
      var validation = new MnistDataSet(validationImages, validationLabels, dtype, reshape);
      var test = new MnistDataSet(testImages, testLabels, dtype, reshape);

      return new Datasets<MnistDataSet>(train, validation, test);
  }

  /// <summary>
  /// Downloads <paramref name="archiveName"/> from <paramref name="sourceUrl"/> into
  /// <c>setting.TrainDir</c>, unzips it there, and returns the path of the unzipped
  /// file (the archive name without its final extension).
  /// </summary>
  private async Task<string> FetchAndUnzipAsync(string sourceUrl, string archiveName, ModelLoadSetting setting)
  {
      await this.DownloadAsync(sourceUrl + archiveName, setting.TrainDir, archiveName)
          .ShowProgressInConsole(setting.ShowProgressInConsole);

      await this.UnzipAsync(Path.Combine(setting.TrainDir, archiveName), setting.TrainDir)
          .ShowProgressInConsole(setting.ShowProgressInConsole);

      return Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(archiveName));
  }
  /// <summary>
  /// Parses an unzipped MNIST image file into an array reshaped to
  /// (images, rows, cols, 1) with one byte per pixel.
  /// </summary>
  /// <param name="file">Path of the file; relative paths are resolved against <c>AppContext.BaseDirectory</c>.</param>
  /// <param name="limit">Optional cap on the number of images read; null reads every image the header declares.</param>
  /// <returns>The pixel data as a <c>np.uint8</c> array.</returns>
  /// <exception cref="InvalidDataException">The file does not start with magic number 2051.</exception>
  /// <exception cref="EndOfStreamException">The file ends before the declared pixel data is complete.</exception>
  private NDArray ExtractImages(string file, int? limit = null)
  {
      if (!Path.IsPathRooted(file))
          file = Path.Combine(AppContext.BaseDirectory, file);

      // Read-only access: the file is never modified, and the two-argument
      // FileStream constructor would request read/write access.
      using (var bytestream = new FileStream(file, FileMode.Open, FileAccess.Read))
      {
          var magic = Read32(bytestream);
          if (magic != 2051)
              throw new InvalidDataException($"Invalid magic number {magic} in MNIST image file: {file}");

          // Header fields follow the magic number in this order: count, rows, cols.
          var num_images = Read32(bytestream);
          if (limit.HasValue)
              num_images = Math.Min(num_images, (uint)limit.Value);

          var rows = Read32(bytestream);
          var cols = Read32(bytestream);

          // Compute the size in 64-bit checked arithmetic so a corrupt header
          // cannot silently wrap around to a small buffer.
          var byteCount = checked((int)((long)rows * cols * num_images));
          var buf = new byte[byteCount];

          // Stream.Read may return fewer bytes than requested, so keep reading
          // until the buffer is full or the file ends early.
          var filled = 0;
          while (filled < buf.Length)
          {
              var read = bytestream.Read(buf, filled, buf.Length - filled);
              if (read == 0)
                  throw new EndOfStreamException($"Unexpected end of MNIST image file: {file}");
              filled += read;
          }

          var data = np.frombuffer(buf, np.uint8);
          data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
          return data;
      }
  }
  /// <summary>
  /// Parses an unzipped MNIST label file into an array of byte labels, optionally
  /// converted to one-hot rows.
  /// </summary>
  /// <param name="file">Path of the file; relative paths are resolved against <c>AppContext.BaseDirectory</c>.</param>
  /// <param name="one_hot">When true, the labels are passed through <c>DenseToOneHot</c> before being returned.</param>
  /// <param name="num_classes">Number of classes handed to <c>DenseToOneHot</c>.</param>
  /// <param name="limit">Optional cap on the number of labels read; null reads every label the header declares.</param>
  /// <returns>The labels as a <c>np.uint8</c> array, or the one-hot matrix when <paramref name="one_hot"/> is true.</returns>
  /// <exception cref="InvalidDataException">The file does not start with magic number 2049.</exception>
  /// <exception cref="EndOfStreamException">The file ends before the declared labels are complete.</exception>
  private NDArray ExtractLabels(string file, bool one_hot = false, int num_classes = 10, int? limit = null)
  {
      if (!Path.IsPathRooted(file))
          file = Path.Combine(AppContext.BaseDirectory, file);

      // Read-only access: the file is never modified, and the two-argument
      // FileStream constructor would request read/write access.
      using (var bytestream = new FileStream(file, FileMode.Open, FileAccess.Read))
      {
          var magic = Read32(bytestream);
          if (magic != 2049)
              throw new InvalidDataException($"Invalid magic number {magic} in MNIST label file: {file}");

          var num_items = Read32(bytestream);
          if (limit.HasValue)
              num_items = Math.Min(num_items, (uint)limit.Value);

          var buf = new byte[num_items];

          // Stream.Read may return fewer bytes than requested, so keep reading
          // until the buffer is full or the file ends early.
          var filled = 0;
          while (filled < buf.Length)
          {
              var read = bytestream.Read(buf, filled, buf.Length - filled);
              if (read == 0)
                  throw new EndOfStreamException($"Unexpected end of MNIST label file: {file}");
              filled += read;
          }

          var labels = np.frombuffer(buf, np.uint8);
          if (one_hot)
              return DenseToOneHot(labels, num_classes);
          return labels;
      }
  }
  /// <summary>
  /// Converts a vector of class indices into a one-hot matrix: row <c>i</c> gets
  /// 1.0 in the column named by label <c>i</c> and stays zero elsewhere.
  /// </summary>
  /// <param name="labels_dense">Labels, read element by element as bytes.</param>
  /// <param name="num_classes">Number of columns in the result.</param>
  /// <returns>An array created by <c>np.zeros(num_labels, num_classes)</c> with one entry set per row.</returns>
  private NDArray DenseToOneHot(NDArray labels_dense, int num_classes)
  {
      var num_labels = labels_dense.shape[0];
      var labels_one_hot = np.zeros(num_labels, num_classes);

      for (int row = 0; row < num_labels; row++)
      {
          // The label value itself is the column index to switch on.
          var col = labels_dense.Data<byte>(row);
          labels_one_hot.SetData(1.0, row, col);
      }

      return labels_one_hot;
  }
  /// <summary>
  /// Reads the next four bytes from the stream and decodes them as a big-endian
  /// unsigned 32-bit integer.
  /// </summary>
  /// <param name="bytestream">Stream positioned at the start of the value; advanced by four bytes.</param>
  /// <returns>The decoded value.</returns>
  /// <exception cref="EndOfStreamException">Fewer than four bytes remain in the stream.</exception>
  private uint Read32(FileStream bytestream)
  {
      var buffer = new byte[sizeof(uint)];

      // Stream.Read may deliver fewer bytes than requested; loop until all four
      // have arrived, and fail loudly instead of decoding a half-filled buffer.
      var filled = 0;
      while (filled < buffer.Length)
      {
          var read = bytestream.Read(buffer, filled, buffer.Length - filled);
          if (read == 0)
              throw new EndOfStreamException("Unexpected end of stream while reading a 32-bit value.");
          filled += read;
      }

      // Most significant byte first (big-endian).
      return ((uint)buffer[0] << 24) | ((uint)buffer[1] << 16) | ((uint)buffer[2] << 8) | buffer[3];
  }
  133. }
  134. }