diff --git a/modelscope/msdatasets/utils/dataset_builder.py b/modelscope/msdatasets/utils/dataset_builder.py
index 0548f7b9..e2f51476 100644
--- a/modelscope/msdatasets/utils/dataset_builder.py
+++ b/modelscope/msdatasets/utils/dataset_builder.py
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
+import copy
 import os
 from typing import Mapping, Sequence, Union
 
@@ -8,6 +9,7 @@
 import pandas as pd
 import pyarrow as pa
 from datasets.info import DatasetInfo
 from datasets.naming import camelcase_to_snakecase
+from datasets.packaged_modules import _EXTENSION_TO_MODULE as exts
 from datasets.packaged_modules import csv
 from datasets.utils.filelock import FileLock
@@ -190,8 +192,54 @@ class TaskSpecificDatasetBuilder(MsCsvDatasetBuilder):
 class ExternalDataset(object):
 
     def __init__(self, split_path_dict, config_kwargs):
-        config_kwargs.update({'split_config': split_path_dict})
-        self.config_kwargs = config_kwargs
+        self.split_path_dict = split_path_dict
+        self.config_kwargs = copy.deepcopy(config_kwargs)
+        self.config_kwargs.update({'split_config': split_path_dict})
+        self.ext_dataset = None
+        self.split_data_files = {k: [] for k, _ in split_path_dict.items()}
+        file_ext = ''
+        for split_name, split_dir in split_path_dict.items():
+            if os.path.isdir(split_dir):
+                split_file_names = os.listdir(split_dir)
+                set_files_exts = set([
+                    os.path.splitext(file_name)[-1].strip('.')
+                    for file_name in split_file_names
+                ])
+                # ensure these files have same extensions
+                if len(set_files_exts) != 1:
+                    supported_exts = ','.join(exts.keys())
+                    logger.error(
+                        f'Split-{split_name} has been ignored, please flatten your folder structure, '
+                        f'and make sure these files have same extensions. '
+                        f'Supported extensions: {supported_exts} .')
+                    continue
+                file_ext = list(set_files_exts)[0]
+
+                split_file_paths = [
+                    os.path.join(split_dir, file_name)
+                    for file_name in split_file_names
+                ]
+                self.split_data_files[split_name] = split_file_paths
+
+        if file_ext and file_ext in exts:
+            file_ext = exts.get(file_ext)
+            self.ext_dataset = datasets.load_dataset(
+                file_ext, data_files=self.split_data_files, **config_kwargs)
 
     def __len__(self):
-        return len(self.config_kwargs['split_config'])
+        return len(self.split_path_dict
+                   ) if not self.ext_dataset else self.ext_dataset.__len__()
+
+    def __getitem__(self, item):
+        if not self.ext_dataset:
+            return self.split_path_dict.get(item)
+        else:
+            return self.ext_dataset.__getitem__(item)
+
+    def __iter__(self):
+        if not self.ext_dataset:
+            for k, v in self.split_path_dict.items():
+                yield k, v
+        else:
+            for k, v in self.ext_dataset.items():
+                yield k, v
diff --git a/modelscope/msdatasets/utils/dataset_utils.py b/modelscope/msdatasets/utils/dataset_utils.py
index 7a46b325..b4c9c177 100644
--- a/modelscope/msdatasets/utils/dataset_utils.py
+++ b/modelscope/msdatasets/utils/dataset_utils.py
@@ -222,7 +222,8 @@ def load_dataset_builder(dataset_name: str, subset_name: str, namespace: str,
             subset_name=subset_name,
             meta_data_files=meta_data_files,
             zip_data_files=zip_data_files,
-            hash=sub_dir)
+            hash=sub_dir,
+            **config_kwargs)
     else:
         raise NotImplementedError(
             f'Dataset mete file extensions "{os.path.splitext(meta_data_file)[-1]}" is not implemented yet'
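
Note (not part of the patch): a minimal usage sketch of how the reworked ExternalDataset is expected to behave, based only on the diff above. The split names, directory paths, and empty config_kwargs below are hypothetical. When every file in a split directory shares one extension listed in datasets' _EXTENSION_TO_MODULE, the splits are materialized through datasets.load_dataset; otherwise the object falls back to exposing the raw split paths.

# Hedged sketch; paths and split names are illustrative only.
from modelscope.msdatasets.utils.dataset_builder import ExternalDataset

split_path_dict = {
    'train': '/data/my_dataset/train',            # hypothetical dir of e.g. *.csv files
    'validation': '/data/my_dataset/validation',  # hypothetical dir
}
ds = ExternalDataset(split_path_dict, config_kwargs={})

print(len(ds))      # number of splits, whether or not the files were loaded
print(ds['train'])  # datasets.Dataset if the files were loadable, else the raw dir path
for split_name, split_data in ds:  # __iter__ yields (split_name, Dataset-or-path) pairs
    print(split_name, type(split_data))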