You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

paddle_data.py 766 B

123456789101112131415161718192021222324252627282930
  1. import numpy as np
  2. from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
  3. if _NEED_IMPORT_PADDLE:
  4. import paddle
  5. from paddle.io import Dataset
  class PaddleNormalDataset(Dataset):
      """Map-style dataset whose samples are the integers ``0 .. num_of_data - 1``.

      Sample ``i`` is the integer ``i`` itself, which makes it easy to check
      which indices a sampler/dataloader actually visited.
      """

      def __init__(self, num_of_data=1000):
          """
          :param num_of_data: number of samples to generate (passed to ``range``).
          """
          self.num_of_data = num_of_data
          self._data = list(range(num_of_data))

      def __len__(self):
          # Derive the length from the stored samples so it always agrees with
          # ``__getitem__``. Returning ``num_of_data`` directly made ``len()``
          # raise ValueError for a negative value, even though ``_data`` is
          # simply empty in that case.
          return len(self._data)

      def __getitem__(self, item):
          """
          :param item: index (plain list indexing, so negatives/slices behave as for ``list``).
          :return: the sample at ``item``; for ``0 <= item < len(self)`` this equals ``item``.
          """
          return self._data[item]
  class PaddleRandomMaxDataset(Dataset):
      """Synthetic dataset of random feature rows labelled by their argmax.

      All features are drawn once, at construction, as a single
      ``(num_samples, num_features)`` tensor via ``paddle.randn``. Each
      sample is a dict ``{"x": row, "y": label}``, where ``label`` is the
      index of the largest entry in ``row``. This class sets no seed, so the
      values depend on paddle's random state at construction time.
      """

      def __init__(self, num_samples, num_features):
          """
          :param num_samples: number of rows (samples) to generate.
          :param num_features: length of each feature row.
          """
          features = paddle.randn((num_samples, num_features))
          self.x = features
          # Label each row with the position of its maximum along the last axis.
          self.y = features.argmax(axis=-1)

      def __len__(self):
          return len(self.x)

      def __getitem__(self, item):
          """
          :param item: sample index.
          :return: dict with the feature row under ``"x"`` and its label under ``"y"``.
          """
          return dict(x=self.x[item], y=self.y[item])