|
|
|
@@ -27,5 +27,54 @@ namespace Tensorflow.Hub |
|
|
|
labels.astype(dataType); |
|
|
|
Labels = labels; |
|
|
|
} |
|
|
|
|
|
|
|
public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true) |
|
|
|
{ |
|
|
|
var start = IndexInEpoch; |
|
|
|
// Shuffle for the first epoch |
|
|
|
if(EpochsCompleted == 0 && start == 0 && shuffle) |
|
|
|
{ |
|
|
|
var perm0 = np.arange(NumOfExamples); |
|
|
|
np.random.shuffle(perm0); |
|
|
|
Data = Data[perm0]; |
|
|
|
Labels = Labels[perm0]; |
|
|
|
} |
|
|
|
|
|
|
|
// Go to the next epoch |
|
|
|
if (start + batch_size > NumOfExamples) |
|
|
|
{ |
|
|
|
// Finished epoch |
|
|
|
EpochsCompleted += 1; |
|
|
|
|
|
|
|
// Get the rest examples in this epoch |
|
|
|
var rest_num_examples = NumOfExamples - start; |
|
|
|
//var images_rest_part = _images[np.arange(start, _num_examples)]; |
|
|
|
//var labels_rest_part = _labels[np.arange(start, _num_examples)]; |
|
|
|
// Shuffle the data |
|
|
|
if (shuffle) |
|
|
|
{ |
|
|
|
var perm = np.arange(NumOfExamples); |
|
|
|
np.random.shuffle(perm); |
|
|
|
Data = Data[perm]; |
|
|
|
Labels = Labels[perm]; |
|
|
|
} |
|
|
|
|
|
|
|
start = 0; |
|
|
|
IndexInEpoch = batch_size - rest_num_examples; |
|
|
|
var end = IndexInEpoch; |
|
|
|
var images_new_part = Data[np.arange(start, end)]; |
|
|
|
var labels_new_part = Labels[np.arange(start, end)]; |
|
|
|
|
|
|
|
/*return (np.concatenate(new float[][] { images_rest_part.Data<float>(), images_new_part.Data<float>() }, axis: 0), |
|
|
|
np.concatenate(new float[][] { labels_rest_part.Data<float>(), labels_new_part.Data<float>() }, axis: 0));*/ |
|
|
|
return (images_new_part, labels_new_part); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
IndexInEpoch += batch_size; |
|
|
|
var end = IndexInEpoch; |
|
|
|
return (Data[np.arange(start, end)], Labels[np.arange(start, end)]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |