You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

MnistDataSet.cs 2.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using NumSharp;
  5. using Tensorflow;
  6. namespace Tensorflow.Hub
  7. {
  /// <summary>
  /// In-memory MNIST split (images + labels) that hands out mini-batches through
  /// <see cref="GetNextBatch"/>, optionally reshuffling at every epoch boundary.
  /// </summary>
  public class MnistDataSet : DataSetBase
  {
      /// <summary>Number of examples; taken from the first dimension of the images passed to the constructor.</summary>
      public int NumOfExamples { get; private set; }

      /// <summary>Number of complete passes over the data that <see cref="GetNextBatch"/> has finished.</summary>
      public int EpochsCompleted { get; private set; }

      /// <summary>Offset within the current epoch at which the next batch starts.</summary>
      public int IndexInEpoch { get; private set; }

      /// <summary>
      /// Flattens each image into a row, casts images and labels to <paramref name="dataType"/>,
      /// and multiplies the image values by 1/255.
      /// </summary>
      /// <param name="images">
      /// Image array indexed as (examples, dim1, dim2, ...); it is reshaped to (examples, dim1 * dim2).
      /// NOTE(review): assumes any trailing dimensions are size 1 (e.g. (N, H, W, 1)) — TODO confirm with callers.
      /// </param>
      /// <param name="labels">Labels aligned with <paramref name="images"/> along the first axis.</param>
      /// <param name="dataType">Element type that both images and labels are cast to.</param>
      /// <param name="reshape">Currently ignored: the images are always flattened.</param>
      public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape)
      {
          EpochsCompleted = 0;
          IndexInEpoch = 0;
          NumOfExamples = images.shape[0];

          // NOTE(review): `reshape` is not consulted; flattening is unconditional.
          images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
          // astype returns a converted copy instead of converting in place, so keep its result.
          images = images.astype(dataType);
          images = np.multiply(images, 1.0f / 255.0f);
          Data = images;

          Labels = labels.astype(dataType);
      }

      /// <summary>
      /// Returns the next <paramref name="batch_size"/> examples. When fewer than
      /// <paramref name="batch_size"/> examples remain in the current epoch, the leftover
      /// examples are joined with the first examples of the next epoch (which is reshuffled
      /// first when <paramref name="shuffle"/> is true), so the batch keeps its full size.
      /// </summary>
      /// <param name="batch_size">Number of examples to return.</param>
      /// <param name="fake_data">Unused; kept for signature compatibility.</param>
      /// <param name="shuffle">Shuffle at the start of the first epoch and at every epoch boundary.</param>
      /// <returns>(images, labels) for the batch, with rows in matching order.</returns>
      public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true)
      {
          var start = IndexInEpoch;

          // Shuffle for the first epoch
          if (EpochsCompleted == 0 && start == 0 && shuffle)
          {
              ShuffleExamples();
          }

          // Enough examples left in this epoch: a plain slice suffices.
          if (start + batch_size <= NumOfExamples)
          {
              IndexInEpoch += batch_size;
              var end = IndexInEpoch;
              return (Data[np.arange(start, end)], Labels[np.arange(start, end)]);
          }

          // Finished epoch
          EpochsCompleted += 1;

          // Grab the tail of this epoch BEFORE reshuffling, otherwise those rows would be lost.
          var restNumExamples = NumOfExamples - start;
          NDArray imagesRestPart = null;
          NDArray labelsRestPart = null;
          if (restNumExamples > 0)
          {
              imagesRestPart = Data[np.arange(start, NumOfExamples)];
              labelsRestPart = Labels[np.arange(start, NumOfExamples)];
          }

          if (shuffle)
          {
              ShuffleExamples();
          }

          // Top the batch up from the beginning of the new epoch.
          IndexInEpoch = batch_size - restNumExamples;
          var newEnd = IndexInEpoch;
          var imagesNewPart = Data[np.arange(0, newEnd)];
          var labelsNewPart = Labels[np.arange(0, newEnd)];

          // The previous epoch ended exactly on a batch boundary: nothing to join.
          if (restNumExamples <= 0)
          {
              return (imagesNewPart, labelsNewPart);
          }

          return (np.concatenate(new NDArray[] { imagesRestPart, imagesNewPart }, axis: 0),
                  np.concatenate(new NDArray[] { labelsRestPart, labelsNewPart }, axis: 0));
      }

      /// <summary>
      /// Applies one random permutation of the example indices to both Data and Labels,
      /// so image rows and label rows stay aligned.
      /// </summary>
      private void ShuffleExamples()
      {
          var perm = np.arange(NumOfExamples);
          np.random.shuffle(perm);
          Data = Data[perm];
          Labels = Labels[perm];
      }
  }
  70. }