You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

paddle_data.py 841 B

1234567891011121314151617181920212223242526272829303132
  1. import numpy as np
  2. from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
  3. if _NEED_IMPORT_PADDLE:
  4. import paddle
  5. from paddle.io import Dataset
  6. else:
  7. from fastNLP.core.utils.dummy_class import DummyClass as Dataset
  class PaddleNormalDataset(Dataset):
      """Dataset holding the integers ``0 .. num_of_data - 1`` in order.

      Args:
          num_of_data: How many consecutive integers to generate. The same
              value is what ``len()`` reports for the dataset.
      """

      def __init__(self, num_of_data=1000):
          self.num_of_data = num_of_data
          # Materialise the samples up front so indexing is a plain list lookup.
          self._data = [*range(num_of_data)]

      def __len__(self):
          # The size is the constructor argument, not the list length.
          size = self.num_of_data
          return size

      def __getitem__(self, item):
          # Indexing is forwarded unchanged to the backing list.
          samples = self._data
          return samples[item]
  class PaddleRandomMaxDataset(Dataset):
      """Random features labelled with the position of their largest entry.

      ``x`` is produced by ``paddle.randn`` with shape
      ``(num_samples, num_features)`` and ``y`` is ``x.argmax(axis=-1)``.

      Args:
          num_samples: First entry of the shape passed to ``paddle.randn``.
          num_features: Second entry of the shape passed to ``paddle.randn``.
      """

      def __init__(self, num_samples, num_features):
          shape = (num_samples, num_features)
          features = paddle.randn(shape)
          self.x = features
          # Labels are derived from the features: argmax over the last axis.
          self.y = features.argmax(axis=-1)

      def __len__(self):
          features = self.x
          return len(features)

      def __getitem__(self, item):
          # Features and label are indexed with the same ``item``.
          feature = self.x[item]
          label = self.y[item]
          return {"x": feature, "y": label}