
test_utils.py

from functools import reduce

# TODO: _TruncatedDataLoader has been modified; remember to update this test accordingly.
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
from tests.helpers.datasets.normal_data import NormalSampler


class Test_WrapDataLoader:

    def test_normal_generator(self):
        # Wrapping a plain generator-style sampler: iteration should stop
        # after exactly `num_batches` batches.
        all_sanity_batches = [4, 20, 100]
        for sanity_batches in all_sanity_batches:
            data = NormalSampler(num_of_data=1000)
            wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
            dataloader = iter(wrapper)
            mark = 0
            while True:
                try:
                    _data = next(dataloader)
                except StopIteration:
                    break
                mark += 1
            assert mark == sanity_batches

    def test_torch_dataloader(self):
        # Wrapping a torch DataLoader: the wrapper should yield `num_batches`
        # full batches, i.e. batch_size * num_batches samples in total.
        from tests.helpers.datasets.torch_data import TorchNormalDataset
        from torch.utils.data import DataLoader

        bses = [8, 16, 40]
        all_sanity_batches = [4, 7, 10]
        for bs in bses:
            for sanity_batches in all_sanity_batches:
                dataset = TorchNormalDataset(num_of_data=1000)
                dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
                wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches)
                dataloader = iter(wrapper)
                all_supposed_running_data_num = 0
                while True:
                    try:
                        _data = next(dataloader)
                    except StopIteration:
                        break
                    all_supposed_running_data_num += _data.shape[0]
                assert all_supposed_running_data_num == bs * sanity_batches

    def test_len(self):
        # len() of the wrapper should report the truncated number of batches,
        # regardless of the underlying dataloader's batch size.
        from tests.helpers.datasets.torch_data import TorchNormalDataset
        from torch.utils.data import DataLoader

        bses = [8, 16, 40]
        all_sanity_batches = [4, 7, 10]
        length = []
        for bs in bses:
            for sanity_batches in all_sanity_batches:
                dataset = TorchNormalDataset(num_of_data=1000)
                dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
                wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches)
                length.append(len(wrapper))
        assert length == reduce(lambda x, y: x + y, [all_sanity_batches for _ in range(len(bses))])
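

# For reference, a minimal sketch of the behaviour these tests assume from
# _TruncatedDataLoader; the real class lives in
# fastNLP.core.controllers.utils.utils and may differ in detail. The name
# _TruncatedDataLoaderSketch is hypothetical and exists only for illustration.
class _TruncatedDataLoaderSketch:
    def __init__(self, dataloader, num_batches: int):
        self.dataloader = dataloader
        self.num_batches = num_batches

    def __iter__(self):
        # Yield at most `num_batches` batches from the wrapped dataloader,
        # then stop early (this is what makes `mark == sanity_batches` hold).
        for idx, batch in enumerate(self.dataloader):
            if idx >= self.num_batches:
                break
            yield batch

    def __len__(self):
        # test_len expects the truncated batch count here; the real
        # implementation may instead clamp against len(self.dataloader).
        return self.num_batches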