You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

paddle_data.py 1.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import paddle
  2. from paddle.io import Dataset
  3. import numpy as np
  class PaddleNormalDataset(Dataset):
      """Map-style dataset whose i-th sample is the integer ``i``.

      The backing store is a plain list holding ``0 .. num_of_data - 1``.
      The size reported by ``__len__`` is the ``num_of_data`` argument
      exactly as it was passed in.
      """

      def __init__(self, num_of_data=1000):
          """Build the integer sequence.

          Args:
              num_of_data: how many consecutive integers (starting at 0) to hold.
          """
          self.num_of_data = num_of_data
          self._data = [*range(num_of_data)]

      def __getitem__(self, item):
          """Return the entry at position ``item`` of the backing list."""
          samples = self._data
          return samples[item]

      def __len__(self):
          """Return the size requested at construction time."""
          return self.num_of_data
  class PaddleRandomDataset(Dataset):
      """Synthetic dataset of random features labelled by their arg-max.

      ``x`` is a ``(num_samples, num_features)`` tensor produced by
      ``paddle.randn``. Each row's target ``y`` is the index of that row's
      largest value along the last axis.
      NOTE(review): no seed is set here, so contents depend on paddle's
      global RNG state at construction time.
      """

      def __init__(self, num_samples, num_features):
          """Draw the feature matrix and derive its labels.

          Args:
              num_samples: number of rows (samples) to generate.
              num_features: number of columns (features) per sample.
          """
          features = paddle.randn((num_samples, num_features))
          self.x = features
          # The label is the position of the largest feature in each row.
          self.y = features.argmax(axis=-1)

      def __len__(self):
          """Return the number of rows in ``x``."""
          return len(self.x)

      def __getitem__(self, item):
          """Return ``{"x": row, "y": label}`` for index ``item``."""
          sample, target = self.x[item], self.y[item]
          return {"x": sample, "y": target}
  class PaddleDataset_MNIST(Dataset):
      """MNIST served as flattened float32 numpy vectors.

      All ``(image, label)`` pairs of the chosen split are read from
      ``paddle.vision.datasets.MNIST`` and held in memory. Each image is
      stored as a 1-D ``float32`` array. Items are returned as
      ``{"x": features, "y": label}`` dicts.
      """

      def __init__(self, mode="train"):
          """Load and convert every sample of the requested split.

          Args:
              mode: split name forwarded unchanged to
                  ``paddle.vision.datasets.MNIST``.
          """
          source = paddle.vision.datasets.MNIST(mode=mode)
          records = []
          for image, target in source:
              # Copy into a numpy array, cast to float32, then flatten to 1-D.
              flat = np.array(image).astype('float32').reshape(-1)
              records.append((flat, target))
          self.dataset = records

      def __len__(self):
          """Return the number of cached samples."""
          return len(self.dataset)

      def __getitem__(self, idx):
          """Return ``{"x": features, "y": label}`` for position ``idx``."""
          record = self.dataset[idx]
          return {"x": record[0], "y": record[1]}