# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import random
import torch.utils.data as data

from detectron2.utils.serialize import PicklableWrapper

__all__ = ["MapDataset", "DatasetFromList"]


class MapDataset(data.Dataset):
    """
    Map a function over the elements in a dataset.

    Args:
        dataset: a dataset where map function is applied.
        map_func: a callable which maps the element in dataset. map_func is
            responsible for error handling, when error happens, it needs to
            return None so the MapDataset will randomly use other
            elements from the dataset.
    """

    def __init__(self, dataset, map_func):
        self._dataset = dataset
        self._map_func = PicklableWrapper(map_func)  # wrap so that a lambda will work

        # Fixed seed so the sequence of fallback indices is reproducible.
        self._rng = random.Random(42)
        # Indices that may be used as a replacement when mapping an element fails.
        # An index is dropped when `map_func` returns None for it and re-added
        # whenever it maps successfully.
        self._fallback_candidates = set(range(len(dataset)))

    def __len__(self):
        """Return the number of elements in the wrapped dataset."""
        return len(self._dataset)

    def __getitem__(self, idx):
        """
        Return the mapped element at `idx`, falling back to other elements.

        If `map_func` returns None for the requested element, another index is
        drawn at random from the pool of fallback candidates and the mapping is
        retried until it yields a non-None result.

        Args:
            idx: index of the requested element; converted with `int()`.

        Returns:
            The first non-None result produced by `map_func`.

        Raises:
            ValueError: if the mapping failed and no fallback candidate is left.
        """
        retry_count = 0
        cur_idx = int(idx)

        while True:
            # Named `mapped` (not `data`) to avoid shadowing the imported
            # `torch.utils.data as data` module alias.
            mapped = self._map_func(self._dataset[cur_idx])
            if mapped is not None:
                self._fallback_candidates.add(cur_idx)
                return mapped

            # _map_func fails for this idx, use a random new index from the pool
            retry_count += 1
            self._fallback_candidates.discard(cur_idx)
            if not self._fallback_candidates:
                raise ValueError(
                    "Failed to apply `_map_func` for idx: {}: "
                    "no fallback candidates left in the dataset".format(idx)
                )
            # `random.sample` no longer accepts a set (deprecated in Python 3.9,
            # TypeError since 3.11), so hand it a sequence. Older versions
            # converted the set to a tuple internally, so the drawn indices are
            # unchanged there.
            cur_idx = self._rng.sample(tuple(self._fallback_candidates), k=1)[0]

            if retry_count >= 3:
                logger = logging.getLogger(__name__)
                logger.warning(
                    "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
                        idx, retry_count
                    )
                )


class DatasetFromList(data.Dataset):
    """
    Wrap a list to a torch Dataset. It produces elements of the list as data.
    """

    def __init__(self, lst: list, copy: bool = True):
        """
        Args:
            lst (list): a list which contains elements to produce.
            copy (bool): whether to deepcopy the element when producing it,
                so that the result can be modified in place without affecting the
                source in the list.
        """
        self._lst = lst
        # NOTE: the parameter `copy` shadows the `copy` module only inside
        # `__init__`; `__getitem__` still sees the module.
        self._copy = copy

    def __len__(self):
        """Return the number of elements in the wrapped list."""
        return len(self._lst)

    def __getitem__(self, idx):
        """Return the element at `idx`, deep-copied when `copy` was enabled."""
        if self._copy:
            return copy.deepcopy(self._lst[idx])
        else:
            return self._lst[idx]