diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py index 0fb877b7..361b8ae0 100644 --- a/modelscope/msdatasets/ms_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -44,44 +44,40 @@ def format_list(para) -> List: return para -class MsIterableDataset(torch.utils.data.IterableDataset): +class MsMapDataset(torch.utils.data.Dataset): def __init__(self, dataset: Iterable, preprocessor_list, retained_columns, - columns): - super(MsIterableDataset).__init__() + columns, to_tensor): + super(MsDataset).__init__() self.dataset = dataset self.preprocessor_list = preprocessor_list + self.to_tensor = to_tensor self.retained_columns = retained_columns self.columns = columns def __len__(self): return len(self.dataset) - def __iter__(self): - worker_info = torch.utils.data.get_worker_info() - if worker_info is None: # single-process data loading - iter_start = 0 - iter_end = len(self.dataset) - else: # in a worker process - per_worker = math.ceil( - len(self.dataset) / float(worker_info.num_workers)) - worker_id = worker_info.id - iter_start = worker_id * per_worker - iter_end = min(iter_start + per_worker, len(self.dataset)) - - for idx in range(iter_start, iter_end): - item_dict = self.dataset[idx] - res = { - k: torch.tensor(item_dict[k]) - for k in self.columns if k in self.retained_columns - } - for preprocessor in self.preprocessor_list: - res.update({ - k: torch.tensor(v) - for k, v in preprocessor(item_dict).items() - if k in self.retained_columns - }) - yield res + def type_converter(self, x): + if self.to_tensor: + return torch.tensor(x) + else: + return x + + def __getitem__(self, index): + item_dict = self.dataset[index] + res = { + k: self.type_converter(item_dict[k]) + for k in self.columns + if (not self.to_tensor) or k in self.retained_columns + } + for preprocessor in self.preprocessor_list: + res.update({ + k: self.type_converter(v) + for k, v in preprocessor(item_dict).items() + if (not self.to_tensor) or k in self.retained_columns + }) + return res class MsDataset: @@ -341,6 +337,7 @@ class MsDataset: self, preprocessors: Union[Callable, List[Callable]], columns: Union[str, List[str]] = None, + to_tensor: bool = True, ): preprocessor_list = preprocessors if isinstance( preprocessors, list) else [preprocessors] @@ -350,28 +347,29 @@ class MsDataset: columns = [ key for key in self._hf_ds.features.keys() if key in columns ] - sample = next(iter(self._hf_ds)) + retained_columns = [] + if to_tensor: + sample = next(iter(self._hf_ds)) - sample_res = {k: np.array(sample[k]) for k in columns} - for processor in preprocessor_list: - sample_res.update( - {k: np.array(v) - for k, v in processor(sample).items()}) + sample_res = {k: np.array(sample[k]) for k in columns} + for processor in preprocessor_list: + sample_res.update( + {k: np.array(v) + for k, v in processor(sample).items()}) - def is_numpy_number(value): - return np.issubdtype(value.dtype, np.integer) or np.issubdtype( - value.dtype, np.floating) + def is_numpy_number(value): + return np.issubdtype(value.dtype, np.integer) or np.issubdtype( + value.dtype, np.floating) - retained_columns = [] - for k in sample_res.keys(): - if not is_numpy_number(sample_res[k]): - logger.warning( - f'Data of column {k} is non-numeric, will be removed') - continue - retained_columns.append(k) + for k in sample_res.keys(): + if not is_numpy_number(sample_res[k]): + logger.warning( + f'Data of column {k} is non-numeric, will be removed') + continue + retained_columns.append(k) - return MsIterableDataset(self._hf_ds, preprocessor_list, - retained_columns, columns) + return MsMapDataset(self._hf_ds, preprocessor_list, retained_columns, + columns, to_tensor) def to_torch_dataset( self, @@ -379,6 +377,7 @@ class MsDataset: preprocessors: Union[Callable, List[Callable]] = None, task_name: str = None, task_data_config: ConfigDict = None, + to_tensor: bool = True, **format_kwargs, ): """Create a torch.utils.data.Dataset from the MS Dataset. The torch.utils.data.Dataset can be passed to @@ -386,13 +385,14 @@ class MsDataset: Args: preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process - every sample of the dataset. The output type of processors is dict, and each numeric field of the dict + every sample of the dataset. The output type of processors is dict, and each (numeric) field of the dict will be used as a field of torch.utils.data.Dataset. - columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only). If the - preprocessor is None, the arg columns must have at least one column. If the `preprocessors` is not None, - the output fields of processors will also be added. + columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only if + `to_tensor` is True). If the preprocessor is None, the arg columns must have at least one column. + If the `preprocessors` is not None, the output fields of processors will also be added. task_name (str, default None): task name, refer to :obj:`Tasks` for more details task_data_config (ConfigDict, default None): config dict for model object. + to_tensor (bool, default None): whether convert the data types of dataset column(s) to torch.tensor or not. format_kwargs: A `dict` of arguments to be passed to the `torch.tensor`. Returns: @@ -409,7 +409,7 @@ class MsDataset: return build_task_dataset(task_data_config, task_name) if preprocessors is not None: return self.to_torch_dataset_with_processors( - preprocessors, columns=columns) + preprocessors, columns=columns, to_tensor=to_tensor) else: self._hf_ds.reset_format() self._hf_ds.set_format(