You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MnistModelLoader.cs 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. using System;
  2. using System.Threading.Tasks;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using System.IO;
  6. using NumSharp;
  7. namespace Tensorflow.Hub
  8. {
  /// <summary>
  /// Loads the MNIST handwritten-digit data set: downloads the four gzip archives
  /// (train/test images and labels), unpacks them into the training directory,
  /// parses the binary files and returns train / validation / test splits.
  /// </summary>
  public class MnistModelLoader : IModelLoader<MnistDataSet>
  {
      private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
      private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
      private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
      private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
      private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";

      // Magic numbers that open an MNIST image file and label file respectively.
      private const uint IMAGE_FILE_MAGIC = 2051;
      private const uint LABEL_FILE_MAGIC = 2049;

      /// <summary>
      /// Convenience entry point: loads MNIST into <paramref name="trainDir"/> using
      /// a <see cref="ModelLoadSetting"/> whose other properties keep their defaults.
      /// </summary>
      /// <param name="trainDir">Directory the archives are downloaded to and unpacked in.</param>
      /// <param name="oneHot">When true, labels are returned one-hot encoded.</param>
      /// <returns>The train, validation and test data sets.</returns>
      public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false)
      {
          var loader = new MnistModelLoader();
          return await loader.LoadAsync(new ModelLoadSetting
          {
              TrainDir = trainDir,
              OneHot = oneHot
          });
      }

      /// <summary>
      /// Downloads, unpacks and parses the MNIST files described by <paramref name="setting"/>.
      /// The first <c>ValidationSize</c> training examples are split off as the validation set;
      /// the remaining training examples form the training set.
      /// </summary>
      /// <param name="setting">Source URL, target directory, size limits and output options.</param>
      /// <returns>The train, validation and test data sets.</returns>
      /// <exception cref="ArgumentException">
      /// The validation size is not smaller than the requested training size, or it
      /// exceeds the number of training examples actually loaded.
      /// </exception>
      public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
      {
          if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value)
              throw new ArgumentException("Validation set should be smaller than training set");

          var sourceUrl = setting.SourceUrl;
          if (string.IsNullOrEmpty(sourceUrl))
              sourceUrl = DEFAULT_SOURCE_URL;

          // Each archive is fetched and parsed in turn, in the same order as before:
          // train images, train labels, test images, test labels.
          var trainImagesFile = await DownloadAndUnzipAsync(sourceUrl, TRAIN_IMAGES, setting);
          var trainImages = ExtractImages(trainImagesFile, limit: setting.TrainSize);

          var trainLabelsFile = await DownloadAndUnzipAsync(sourceUrl, TRAIN_LABELS, setting);
          var trainLabels = ExtractLabels(trainLabelsFile, one_hot: setting.OneHot, limit: setting.TrainSize);

          var testImagesFile = await DownloadAndUnzipAsync(sourceUrl, TEST_IMAGES, setting);
          var testImages = ExtractImages(testImagesFile, limit: setting.TestSize);

          var testLabelsFile = await DownloadAndUnzipAsync(sourceUrl, TEST_LABELS, setting);
          var testLabels = ExtractLabels(testLabelsFile, one_hot: setting.OneHot, limit: setting.TestSize);

          var end = trainImages.shape[0];
          var validationSize = setting.ValidationSize;

          // The up-front check only covers an explicit TrainSize; guard the case where
          // the file itself holds fewer examples than the requested validation size.
          if (validationSize > end)
              throw new ArgumentException($"Validation size {validationSize} exceeds the {end} training examples that were loaded");

          var validationImages = trainImages[np.arange(validationSize)];
          var validationLabels = trainLabels[np.arange(validationSize)];

          trainImages = trainImages[np.arange(validationSize, end)];
          trainLabels = trainLabels[np.arange(validationSize, end)];

          var dtype = setting.DtType;
          var reshape = setting.ReShape;

          var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape);
          var validation = new MnistDataSet(validationImages, validationLabels, dtype, reshape);
          // Fix: the test split must be built from the test files, not the training data.
          var test = new MnistDataSet(testImages, testLabels, dtype, reshape);

          return new Datasets<MnistDataSet>(train, validation, test);
      }

      /// <summary>
      /// Downloads one archive into the training directory, unpacks it there and
      /// returns the path the unpacked file is expected at (the archive name without
      /// its final extension).
      /// </summary>
      /// <param name="sourceUrl">Base URL the archive name is appended to.</param>
      /// <param name="archiveName">File name of the archive, e.g. one of the *.gz constants.</param>
      /// <param name="setting">Supplies the target directory and the progress-display flag.</param>
      private async Task<string> DownloadAndUnzipAsync(string sourceUrl, string archiveName, ModelLoadSetting setting)
      {
          // NOTE(review): DownloadAsync / UnzipAsync / ShowProgressInConsole are defined
          // outside this file; their exact semantics are not visible here.
          await this.DownloadAsync(sourceUrl + archiveName, setting.TrainDir, archiveName)
              .ShowProgressInConsole(setting.ShowProgressInConsole);

          await this.UnzipAsync(Path.Combine(setting.TrainDir, archiveName), setting.TrainDir)
              .ShowProgressInConsole(setting.ShowProgressInConsole);

          return Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(archiveName));
      }

      /// <summary>
      /// Parses an unpacked MNIST image file into an array reshaped to
      /// (images, rows, cols, 1) of unsigned bytes.
      /// </summary>
      /// <param name="file">Path of the unpacked image file.</param>
      /// <param name="limit">Optional cap on the number of images read.</param>
      /// <exception cref="ValueError">The file does not start with the image magic number.</exception>
      /// <exception cref="EndOfStreamException">The file is shorter than its header claims.</exception>
      private NDArray ExtractImages(string file, int? limit = null)
      {
          // File.OpenRead asks for read access only, so read-only files can be loaded.
          using (var bytestream = File.OpenRead(file))
          {
              var magic = Read32(bytestream);
              if (magic != IMAGE_FILE_MAGIC)
                  throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");

              var num_images = Read32(bytestream);
              if (limit.HasValue)
                  num_images = Math.Min(num_images, (uint)limit.Value);

              var rows = Read32(bytestream);
              var cols = Read32(bytestream);

              // checked: a corrupt header must not silently wrap the buffer size.
              var buf = ReadExactly(bytestream, checked((int)(rows * cols * num_images)));

              var data = np.frombuffer(buf, np.uint8);
              data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
              return data;
          }
      }

      /// <summary>
      /// Parses an unpacked MNIST label file into an array of unsigned-byte labels,
      /// optionally converted to one-hot rows.
      /// </summary>
      /// <param name="file">Path of the unpacked label file.</param>
      /// <param name="one_hot">When true, the labels are returned one-hot encoded.</param>
      /// <param name="num_classes">Number of columns of the one-hot encoding.</param>
      /// <param name="limit">Optional cap on the number of labels read.</param>
      /// <exception cref="ValueError">The file does not start with the label magic number.</exception>
      /// <exception cref="EndOfStreamException">The file is shorter than its header claims.</exception>
      private NDArray ExtractLabels(string file, bool one_hot = false, int num_classes = 10, int? limit = null)
      {
          using (var bytestream = File.OpenRead(file))
          {
              var magic = Read32(bytestream);
              if (magic != LABEL_FILE_MAGIC)
                  throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");

              var num_items = Read32(bytestream);
              if (limit.HasValue)
                  num_items = Math.Min(num_items, (uint)limit.Value);

              var buf = ReadExactly(bytestream, checked((int)num_items));
              var labels = np.frombuffer(buf, np.uint8);

              return one_hot ? DenseToOneHot(labels, num_classes) : labels;
          }
      }

      /// <summary>
      /// Converts a vector of class indices into a (labels, num_classes) array that
      /// holds 1.0 in the column named by each label and zeros elsewhere.
      /// </summary>
      /// <param name="labels_dense">Byte labels, one per example.</param>
      /// <param name="num_classes">Number of columns in the result.</param>
      private NDArray DenseToOneHot(NDArray labels_dense, int num_classes)
      {
          var num_labels = labels_dense.shape[0];
          var labels_one_hot = np.zeros(num_labels, num_classes);

          for (int row = 0; row < num_labels; row++)
          {
              var col = labels_dense.Data<byte>(row);
              labels_one_hot.SetData(1.0, row, col);
          }

          return labels_one_hot;
      }

      /// <summary>
      /// Reads the next four bytes of the stream as a big-endian unsigned 32-bit integer.
      /// </summary>
      /// <exception cref="EndOfStreamException">Fewer than four bytes remain.</exception>
      private uint Read32(FileStream bytestream)
      {
          var buffer = ReadExactly(bytestream, sizeof(uint));

          // Most significant byte first, independent of the host's endianness.
          return ((uint)buffer[0] << 24)
               | ((uint)buffer[1] << 16)
               | ((uint)buffer[2] << 8)
               | (uint)buffer[3];
      }

      /// <summary>
      /// Reads exactly <paramref name="count"/> bytes. A single Stream.Read call may
      /// return fewer bytes than requested, so this loops until the buffer is full.
      /// </summary>
      /// <exception cref="EndOfStreamException">The stream ended before <paramref name="count"/> bytes were read.</exception>
      private static byte[] ReadExactly(Stream stream, int count)
      {
          var buffer = new byte[count];
          var offset = 0;

          while (offset < count)
          {
              var read = stream.Read(buffer, offset, count - offset);
              if (read == 0)
                  throw new EndOfStreamException($"Unexpected end of MNIST file: expected {count} bytes but only {offset} were available");
              offset += read;
          }

          return buffer;
      }
  }
  123. }