diff --git a/src/TensorFlowHub/MnistDataSet.cs b/src/TensorFlowHub/MnistDataSet.cs
index e0717ccb..accc57e1 100644
--- a/src/TensorFlowHub/MnistDataSet.cs
+++ b/src/TensorFlowHub/MnistDataSet.cs
@@ -27,5 +27,74 @@ namespace Tensorflow.Hub
             labels.astype(dataType);
             Labels = labels;
         }
+
+        /// <summary>
+        /// Returns the next <paramref name="batch_size"/> examples and advances the epoch cursor.
+        /// </summary>
+        /// <remarks>
+        /// When fewer than <paramref name="batch_size"/> examples remain in the current epoch, the
+        /// batch is made of those remaining examples followed by the first examples of the next
+        /// epoch, so the returned arrays always hold exactly <paramref name="batch_size"/> entries.
+        /// </remarks>
+        /// <param name="batch_size">Examples to return; from 0 to <c>NumOfExamples</c> inclusive.</param>
+        /// <param name="fake_data">Currently ignored; kept so existing callers keep compiling.</param>
+        /// <param name="shuffle">Whether to reorder <c>Data</c> and <c>Labels</c> at each epoch start.</param>
+        /// <returns>The selected entries of <c>Data</c> and the matching entries of <c>Labels</c>.</returns>
+        /// <exception cref="System.ArgumentOutOfRangeException">
+        /// <paramref name="batch_size"/> is negative or larger than <c>NumOfExamples</c>.
+        /// </exception>
+        public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true)
+        {
+            // A batch may straddle at most one epoch boundary, so it cannot exceed the data set size.
+            if (batch_size < 0 || batch_size > NumOfExamples)
+                throw new System.ArgumentOutOfRangeException(nameof(batch_size));
+
+            var start = IndexInEpoch;
+
+            // Shuffle once before the very first batch is served.
+            if (EpochsCompleted == 0 && start == 0 && shuffle)
+                ShuffleExamples();
+
+            // Common case: the whole batch fits in the current epoch.
+            if (start + batch_size <= NumOfExamples)
+            {
+                IndexInEpoch += batch_size;
+                var end = IndexInEpoch;
+                return (Data[np.arange(start, end)], Labels[np.arange(start, end)]);
+            }
+
+            // The batch crosses the epoch boundary.
+            EpochsCompleted += 1;
+
+            // Take the tail of the finished epoch before the data is reshuffled. Skipped when the
+            // previous batch ended exactly on the boundary, because there is no tail to take.
+            var rest_num_examples = NumOfExamples - start;
+            var has_rest = rest_num_examples > 0;
+            NDArray images_rest_part = has_rest ? Data[np.arange(start, NumOfExamples)] : null;
+            NDArray labels_rest_part = has_rest ? Labels[np.arange(start, NumOfExamples)] : null;
+
+            if (shuffle)
+                ShuffleExamples();
+
+            // Fill the remainder of the batch from the head of the new epoch.
+            IndexInEpoch = batch_size - rest_num_examples;
+            var images_new_part = Data[np.arange(0, IndexInEpoch)];
+            var labels_new_part = Labels[np.arange(0, IndexInEpoch)];
+
+            if (!has_rest)
+                return (images_new_part, labels_new_part);
+
+            return (np.concatenate(new[] { images_rest_part, images_new_part }, axis: 0),
+                np.concatenate(new[] { labels_rest_part, labels_new_part }, axis: 0));
+        }
+
+        // Applies one random permutation to both Data and Labels so examples stay paired with labels.
+        private void ShuffleExamples()
+        {
+            var perm = np.arange(NumOfExamples);
+            np.random.shuffle(perm);
+            Data = Data[perm];
+            Labels = Labels[perm];
+        }
     }
 }