Browse Source

added method GetNextBatch to MnistDataSet

pull/327/head
Kerry Jiang 6 years ago
parent
commit
2f2170e50f
1 changed files with 49 additions and 0 deletions
  1. +49
    -0
      src/TensorFlowHub/MnistDataSet.cs

+ 49
- 0
src/TensorFlowHub/MnistDataSet.cs View File

@@ -27,5 +27,54 @@ namespace Tensorflow.Hub
labels.astype(dataType);
Labels = labels;
}

public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true)
{
var start = IndexInEpoch;
// Shuffle for the first epoch
if(EpochsCompleted == 0 && start == 0 && shuffle)
{
var perm0 = np.arange(NumOfExamples);
np.random.shuffle(perm0);
Data = Data[perm0];
Labels = Labels[perm0];
}

// Go to the next epoch
if (start + batch_size > NumOfExamples)
{
// Finished epoch
EpochsCompleted += 1;

// Get the rest examples in this epoch
var rest_num_examples = NumOfExamples - start;
//var images_rest_part = _images[np.arange(start, _num_examples)];
//var labels_rest_part = _labels[np.arange(start, _num_examples)];
// Shuffle the data
if (shuffle)
{
var perm = np.arange(NumOfExamples);
np.random.shuffle(perm);
Data = Data[perm];
Labels = Labels[perm];
}

start = 0;
IndexInEpoch = batch_size - rest_num_examples;
var end = IndexInEpoch;
var images_new_part = Data[np.arange(start, end)];
var labels_new_part = Labels[np.arange(start, end)];

/*return (np.concatenate(new float[][] { images_rest_part.Data<float>(), images_new_part.Data<float>() }, axis: 0),
np.concatenate(new float[][] { labels_rest_part.Data<float>(), labels_new_part.Data<float>() }, axis: 0));*/
return (images_new_part, labels_new_part);
}
else
{
IndexInEpoch += batch_size;
var end = IndexInEpoch;
return (Data[np.arange(start, end)], Labels[np.arange(start, end)]);
}
}
}
}

Loading…
Cancel
Save