You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Datasets.cs 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using NumSharp;
  5. namespace Tensorflow.Hub
  6. {
  7. public class Datasets<TDataSet> where TDataSet : IDataSet
  8. {
  9. public TDataSet Train { get; private set; }
  10. public TDataSet Validation { get; private set; }
  11. public TDataSet Test { get; private set; }
  12. public Datasets(TDataSet train, TDataSet validation, TDataSet test)
  13. {
  14. Train = train;
  15. Validation = validation;
  16. Test = test;
  17. }
  18. public (NDArray, NDArray) Randomize(NDArray x, NDArray y)
  19. {
  20. var perm = np.random.permutation(y.shape[0]);
  21. np.random.shuffle(perm);
  22. return (x[perm], y[perm]);
  23. }
  24. /// <summary>
  25. /// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method)
  26. /// </summary>
  27. /// <param name="x"></param>
  28. /// <param name="y"></param>
  29. /// <param name="start"></param>
  30. /// <param name="end"></param>
  31. /// <returns></returns>
  32. public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
  33. {
  34. var slice = new Slice(start, end);
  35. var x_batch = x[slice];
  36. var y_batch = y[slice];
  37. return (x_batch, y_batch);
  38. }
  39. }
  40. }