You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common.py 2.5 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import copy
  3. import logging
  4. import random
  5. import torch.utils.data as data
  6. from detectron2.utils.serialize import PicklableWrapper
  7. __all__ = ["MapDataset", "DatasetFromList"]
  8. class MapDataset(data.Dataset):
  9. """
  10. Map a function over the elements in a dataset.
  11. Args:
  12. dataset: a dataset where map function is applied.
  13. map_func: a callable which maps the element in dataset. map_func is
  14. responsible for error handling, when error happens, it needs to
  15. return None so the MapDataset will randomly use other
  16. elements from the dataset.
  17. """
  18. def __init__(self, dataset, map_func):
  19. self._dataset = dataset
  20. self._map_func = PicklableWrapper(map_func) # wrap so that a lambda will work
  21. self._rng = random.Random(42)
  22. self._fallback_candidates = set(range(len(dataset)))
  23. def __len__(self):
  24. return len(self._dataset)
  25. def __getitem__(self, idx):
  26. retry_count = 0
  27. cur_idx = int(idx)
  28. while True:
  29. data = self._map_func(self._dataset[cur_idx])
  30. if data is not None:
  31. self._fallback_candidates.add(cur_idx)
  32. return data
  33. # _map_func fails for this idx, use a random new index from the pool
  34. retry_count += 1
  35. self._fallback_candidates.discard(cur_idx)
  36. cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0]
  37. if retry_count >= 3:
  38. logger = logging.getLogger(__name__)
  39. logger.warning(
  40. "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
  41. idx, retry_count
  42. )
  43. )
  44. class DatasetFromList(data.Dataset):
  45. """
  46. Wrap a list to a torch Dataset. It produces elements of the list as data.
  47. """
  48. def __init__(self, lst: list, copy: bool = True):
  49. """
  50. Args:
  51. lst (list): a list which contains elements to produce.
  52. copy (bool): whether to deepcopy the element when producing it,
  53. so that the result can be modified in place without affecting the
  54. source in the list.
  55. """
  56. self._lst = lst
  57. self._copy = copy
  58. def __len__(self):
  59. return len(self._lst)
  60. def __getitem__(self, idx):
  61. if self._copy:
  62. return copy.deepcopy(self._lst[idx])
  63. else:
  64. return self._lst[idx]

No Description