From c7238a470bb666ac147571d4a2b28a5d2ac09d35 Mon Sep 17 00:00:00 2001 From: "feiwu.yfw" Date: Tue, 21 Jun 2022 11:10:28 +0800 Subject: [PATCH] [to #42670107]pydataset fetch data from datahub * pydataset fetch data from datahub Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9060856 --- modelscope/preprocessors/nlp.py | 39 +- modelscope/pydatasets/config.py | 22 ++ modelscope/pydatasets/py_dataset.py | 381 +++++++++++++++++--- modelscope/pydatasets/utils/__init__.py | 0 modelscope/pydatasets/utils/ms_api.py | 66 ++++ modelscope/utils/test_utils.py | 15 + tests/pipelines/test_image_matting.py | 11 + tests/pipelines/test_text_classification.py | 28 +- tests/pydatasets/test_py_dataset.py | 121 +++++-- 9 files changed, 580 insertions(+), 103 deletions(-) create mode 100644 modelscope/pydatasets/config.py create mode 100644 modelscope/pydatasets/utils/__init__.py create mode 100644 modelscope/pydatasets/utils/ms_api.py diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py index 9bcaa87c..0abb01cc 100644 --- a/modelscope/preprocessors/nlp.py +++ b/modelscope/preprocessors/nlp.py @@ -53,12 +53,12 @@ class SequenceClassificationPreprocessor(Preprocessor): self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) print(f'this is the tokenzier {self.tokenizer}') - @type_assert(object, (str, tuple)) - def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]: + @type_assert(object, (str, tuple, Dict)) + def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]: """process the raw input data Args: - data (str or tuple): + data (str or tuple, Dict): sentence1 (str): a sentence Example: 'you are so handsome.' @@ -70,22 +70,31 @@ class SequenceClassificationPreprocessor(Preprocessor): sentence2 (str): a sentence Example: 'you are so beautiful.' + or + {field1: field_value1, field2: field_value2} + field1 (str): field name, default 'first_sequence' + field_value1 (str): a sentence + Example: + 'you are so handsome.' + + field2 (str): field name, default 'second_sequence' + field_value2 (str): a sentence + Example: + 'you are so beautiful.' 
Returns: Dict[str, Any]: the preprocessed data """ - - if not isinstance(data, tuple): - data = ( - data, - None, - ) - - sentence1, sentence2 = data - new_data = { - self.first_sequence: sentence1, - self.second_sequence: sentence2 - } + if isinstance(data, str): + new_data = {self.first_sequence: data} + elif isinstance(data, tuple): + sentence1, sentence2 = data + new_data = { + self.first_sequence: sentence1, + self.second_sequence: sentence2 + } + else: + new_data = data # preprocess the data for the model input diff --git a/modelscope/pydatasets/config.py b/modelscope/pydatasets/config.py new file mode 100644 index 00000000..e916b3ec --- /dev/null +++ b/modelscope/pydatasets/config.py @@ -0,0 +1,22 @@ +import os +from pathlib import Path + +# Cache location +DEFAULT_CACHE_HOME = '~/.cache' +CACHE_HOME = os.getenv('CACHE_HOME', DEFAULT_CACHE_HOME) +DEFAULT_MS_CACHE_HOME = os.path.join(CACHE_HOME, 'modelscope/hub') +MS_CACHE_HOME = os.path.expanduser( + os.getenv('MS_CACHE_HOME', DEFAULT_MS_CACHE_HOME)) + +DEFAULT_MS_DATASETS_CACHE = os.path.join(MS_CACHE_HOME, 'datasets') +MS_DATASETS_CACHE = Path( + os.getenv('MS_DATASETS_CACHE', DEFAULT_MS_DATASETS_CACHE)) + +DOWNLOADED_DATASETS_DIR = 'downloads' +DEFAULT_DOWNLOADED_DATASETS_PATH = os.path.join(MS_DATASETS_CACHE, + DOWNLOADED_DATASETS_DIR) +DOWNLOADED_DATASETS_PATH = Path( + os.getenv('DOWNLOADED_DATASETS_PATH', DEFAULT_DOWNLOADED_DATASETS_PATH)) + +MS_HUB_ENDPOINT = os.environ.get('MS_HUB_ENDPOINT', + 'http://101.201.119.157:31752') diff --git a/modelscope/pydatasets/py_dataset.py b/modelscope/pydatasets/py_dataset.py index 78aedaa0..49137253 100644 --- a/modelscope/pydatasets/py_dataset.py +++ b/modelscope/pydatasets/py_dataset.py @@ -1,64 +1,81 @@ -from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence, - Union) +import os +from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional, + Sequence, Union) -from datasets import Dataset, load_dataset +import numpy as np +from datasets import Dataset +from datasets import load_dataset as hf_load_dataset +from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE +from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES +from datasets.utils.file_utils import (is_relative_path, + relative_to_absolute_path) +from modelscope.pydatasets.config import MS_DATASETS_CACHE +from modelscope.pydatasets.utils.ms_api import MsApi from modelscope.utils.constant import Hubs from modelscope.utils.logger import get_logger logger = get_logger() +def format_list(para) -> List: + if para is None: + para = [] + elif isinstance(para, str): + para = [para] + elif len(set(para)) < len(para): + raise ValueError(f'List columns contains duplicates: {para}') + return para + + class PyDataset: _hf_ds = None # holds the underlying HuggingFace Dataset """A PyDataset backed by hugging face Dataset.""" - def __init__(self, hf_ds: Dataset): + def __init__(self, hf_ds: Dataset, target: Optional[str] = None): self._hf_ds = hf_ds - self.target = None + self.target = target def __iter__(self): - if isinstance(self._hf_ds, Dataset): - for item in self._hf_ds: - if self.target is not None: - yield item[self.target] - else: - yield item - else: - for ds in self._hf_ds.values(): - for item in ds: - if self.target is not None: - yield item[self.target] - else: - yield item + for item in self._hf_ds: + if self.target is not None: + yield item[self.target] + else: + yield item + + def __getitem__(self, key): + return self._hf_ds[key] @classmethod def from_hf_dataset(cls, hf_ds: Dataset, - 
target: str = None) -> 'PyDataset': - dataset = cls(hf_ds) - dataset.target = target - return dataset + target: str = None) -> Union[dict, 'PyDataset']: + if isinstance(hf_ds, Dataset): + return cls(hf_ds, target) + if len(hf_ds.keys()) == 1: + return cls(next(iter(hf_ds.values())), target) + return {k: cls(v, target) for k, v in hf_ds.items()} @staticmethod - def load(path: Union[str, list], - target: Optional[str] = None, - version: Optional[str] = None, - name: Optional[str] = None, - split: Optional[str] = None, - data_dir: Optional[str] = None, - data_files: Optional[Union[str, Sequence[str], - Mapping[str, - Union[str, - Sequence[str]]]]] = None, - hub: Optional[Hubs] = None) -> 'PyDataset': + def load( + dataset_name: Union[str, list], + target: Optional[str] = None, + version: Optional[str] = None, + hub: Optional[Hubs] = Hubs.modelscope, + subset_name: Optional[str] = None, + split: Optional[str] = None, + data_dir: Optional[str] = None, + data_files: Optional[Union[str, Sequence[str], + Mapping[str, Union[str, + Sequence[str]]]]] = None + ) -> Union[dict, 'PyDataset']: """Load a PyDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset. Args: - path (str): Path or name of the dataset. + dataset_name (str): Path or name of the dataset. target (str, optional): Name of the column to output. version (str, optional): Version of the dataset script to load: - name (str, optional): Defining the subset_name of the dataset. + subset_name (str, optional): Defining the subset_name of the dataset. data_dir (str, optional): Defining the data_dir of the dataset configuration. I data_files (str or Sequence or Mapping, optional): Path(s) to source data file(s). split (str, optional): Which split of the data to load. @@ -67,53 +84,302 @@ class PyDataset: Returns: PyDataset (obj:`PyDataset`): PyDataset object for a certain dataset. 
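+        Examples (an illustrative sketch only, mirroring the new tests in this patch; the 'squad' and 'glue'
+            dataset names are assumed to be resolvable on the respective hubs):
+            >>> from modelscope.pydatasets import PyDataset
+            >>> from modelscope.utils.constant import Hubs
+            >>> # fetch a dataset script from the ModelScope datahub (the default hub)
+            >>> ds = PyDataset.load('squad', split='train', target='context')
+            >>> # the same interface against the Hugging Face hub
+            >>> ds = PyDataset.load(
+            ...     'glue', subset_name='sst2', split='train', hub=Hubs.huggingface)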
""" - if Hubs.modelscope == hub: - # TODO: parse data meta information from modelscope hub - # and possibly download data files to local (and update path) - print('getting data from modelscope hub') - if isinstance(path, str): - dataset = load_dataset( - path, - name=name, + if hub == Hubs.huggingface: + dataset = hf_load_dataset( + dataset_name, + name=subset_name, revision=version, split=split, data_dir=data_dir, data_files=data_files) - elif isinstance(path, list): + return PyDataset.from_hf_dataset(dataset, target=target) + else: + return PyDataset._load_ms_dataset( + dataset_name, + target=target, + subset_name=subset_name, + version=version, + split=split, + data_dir=data_dir, + data_files=data_files) + + @staticmethod + def _load_ms_dataset( + dataset_name: Union[str, list], + target: Optional[str] = None, + version: Optional[str] = None, + subset_name: Optional[str] = None, + split: Optional[str] = None, + data_dir: Optional[str] = None, + data_files: Optional[Union[str, Sequence[str], + Mapping[str, Union[str, + Sequence[str]]]]] = None + ) -> Union[dict, 'PyDataset']: + if isinstance(dataset_name, str): + use_hf = False + if dataset_name in _PACKAGED_DATASETS_MODULES or os.path.isdir(dataset_name) or \ + (os.path.isfile(dataset_name) and dataset_name.endswith('.py')): + use_hf = True + elif is_relative_path(dataset_name): + ms_api = MsApi() + dataset_scripts = ms_api.fetch_dataset_scripts( + dataset_name, version) + if 'py' in dataset_scripts: # dataset copied from hf datasets + dataset_name = dataset_scripts['py'][0] + use_hf = True + else: + raise FileNotFoundError( + f"Couldn't find a dataset script at {relative_to_absolute_path(dataset_name)} " + f'or any data file in the same directory.') + + if use_hf: + dataset = hf_load_dataset( + dataset_name, + name=subset_name, + revision=version, + split=split, + data_dir=data_dir, + data_files=data_files, + cache_dir=MS_DATASETS_CACHE) + else: + # TODO load from ms datahub + raise NotImplementedError( + f'Dataset {dataset_name} load from modelscope datahub to be implemented in ' + f'the future') + elif isinstance(dataset_name, list): if target is None: target = 'target' - dataset = Dataset.from_dict({target: [p] for p in path}) + dataset = Dataset.from_dict({target: dataset_name}) else: raise TypeError('path must be a str or a list, but got' - f' {type(path)}') + f' {type(dataset_name)}') return PyDataset.from_hf_dataset(dataset, target=target) + def to_torch_dataset_with_processors( + self, + preprocessors: Union[Callable, List[Callable]], + columns: Union[str, List[str]] = None, + ): + preprocessor_list = preprocessors if isinstance( + preprocessors, list) else [preprocessors] + + columns = format_list(columns) + + columns = [ + key for key in self._hf_ds.features.keys() if key in columns + ] + sample = next(iter(self._hf_ds)) + + sample_res = {k: np.array(sample[k]) for k in columns} + for processor in preprocessor_list: + sample_res.update( + {k: np.array(v) + for k, v in processor(sample).items()}) + + def is_numpy_number(value): + return np.issubdtype(value.dtype, np.integer) or np.issubdtype( + value.dtype, np.floating) + + retained_columns = [] + for k in sample_res.keys(): + if not is_numpy_number(sample_res[k]): + logger.warning( + f'Data of column {k} is non-numeric, will be removed') + continue + retained_columns.append(k) + + import torch + + class MsIterableDataset(torch.utils.data.IterableDataset): + + def __init__(self, dataset: Iterable): + super(MsIterableDataset).__init__() + self.dataset = dataset + + def 
__iter__(self): + for item_dict in self.dataset: + res = { + k: np.array(item_dict[k]) + for k in columns if k in retained_columns + } + for preprocessor in preprocessor_list: + res.update({ + k: np.array(v) + for k, v in preprocessor(item_dict).items() + if k in retained_columns + }) + yield res + + return MsIterableDataset(self._hf_ds) + def to_torch_dataset( self, columns: Union[str, List[str]] = None, - output_all_columns: bool = False, + preprocessors: Union[Callable, List[Callable]] = None, **format_kwargs, ): - self._hf_ds.reset_format() - self._hf_ds.set_format( - type='torch', - columns=columns, - output_all_columns=output_all_columns, - format_kwargs=format_kwargs) - return self._hf_ds + """Create a torch.utils.data.Dataset from the MS Dataset. The torch.utils.data.Dataset can be passed to + torch.utils.data.DataLoader. + + Args: + preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process + every sample of the dataset. The output type of processors is dict, and each numeric field of the dict + will be used as a field of torch.utils.data.Dataset. + columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only). If the + preprocessor is None, the arg columns must have at least one column. If the `preprocessors` is not None, + the output fields of processors will also be added. + format_kwargs: A `dict` of arguments to be passed to `torch.tensor`. + + Returns: + :class:`torch.utils.data.Dataset` + + """ + if not TORCH_AVAILABLE: + raise ImportError( + 'The function to_torch_dataset requires PyTorch to be installed' + ) + if preprocessors is not None: + return self.to_torch_dataset_with_processors(preprocessors) + else: + self._hf_ds.reset_format() + self._hf_ds.set_format( + type='torch', columns=columns, format_kwargs=format_kwargs) + return self._hf_ds + + def to_tf_dataset_with_processors( self, batch_size: int, shuffle: bool, preprocessors: Union[Callable, List[Callable]], drop_remainder: bool = None, prefetch: bool = True, label_cols: Union[str, List[str]] = None, columns: Union[str, List[str]] = None, ): + preprocessor_list = preprocessors if isinstance( + preprocessors, list) else [preprocessors] + + label_cols = format_list(label_cols) + columns = format_list(columns) + cols_to_retain = list(set(label_cols + columns)) + retained_columns = [ + key for key in self._hf_ds.features.keys() if key in cols_to_retain + ] + import tensorflow as tf + tf_dataset = tf.data.Dataset.from_tensor_slices( + np.arange(len(self._hf_ds), dtype=np.int64)) + if shuffle: + tf_dataset = tf_dataset.shuffle(buffer_size=len(self._hf_ds)) + + def func(i, return_dict=False): + i = int(i) + res = {k: np.array(self._hf_ds[i][k]) for k in retained_columns} + for preprocessor in preprocessor_list: + # TODO preprocessor output may have the same key + res.update({ + k: np.array(v) + for k, v in preprocessor(self._hf_ds[i]).items() + }) + if return_dict: + return res + return tuple(list(res.values())) + + sample_res = func(0, True) + + @tf.function(input_signature=[tf.TensorSpec(None, tf.int64)]) + def fetch_function(i): + output = tf.numpy_function( + func, + inp=[i], + Tout=[ + tf.dtypes.as_dtype(val.dtype) + for val in sample_res.values() + ], + ) + return {key: output[i] for i, key in enumerate(sample_res)} + + tf_dataset = tf_dataset.map( + fetch_function, num_parallel_calls=tf.data.AUTOTUNE) + if label_cols: + + def split_features_and_labels(input_batch): + labels = { + key: tensor + for key, tensor in
input_batch.items() if key in label_cols + } + if len(input_batch) == 1: + input_batch = next(iter(input_batch.values())) + if len(labels) == 1: + labels = next(iter(labels.values())) + return input_batch, labels + + tf_dataset = tf_dataset.map(split_features_and_labels) + + elif len(columns) == 1: + tf_dataset = tf_dataset.map(lambda x: next(iter(x.values()))) + if batch_size > 1: + tf_dataset = tf_dataset.batch( + batch_size, drop_remainder=drop_remainder) + + if prefetch: + tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE) + return tf_dataset def to_tf_dataset( self, - columns: Union[str, List[str]], batch_size: int, shuffle: bool, - collate_fn: Callable, + preprocessors: Union[Callable, List[Callable]] = None, + columns: Union[str, List[str]] = None, + collate_fn: Callable = None, drop_remainder: bool = None, collate_fn_args: Dict[str, Any] = None, label_cols: Union[str, List[str]] = None, - dummy_labels: bool = False, prefetch: bool = True, ): + """Create a tf.data.Dataset from the MS Dataset. This tf.data.Dataset can be passed to tf methods like + model.fit() or model.predict(). + + Args: + batch_size (int): Number of samples in a single batch. + shuffle (bool): Shuffle the dataset order. + preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process + every sample of the dataset. The output type of processors is dict, and each field of the dict will be + used as a field of the tf.data.Dataset. If the `preprocessors` is None, the `collate_fn` + shouldn't be None. + columns (str or List[str], default None): Dataset column(s) to be loaded. If the preprocessor is None, + the arg columns must have at least one column. If the `preprocessors` is not None, the output fields of + processors will also be added. + collate_fn (Callable, default None): A callable object used to collect lists of samples into a batch. If + the `preprocessors` is None, the `collate_fn` shouldn't be None. + drop_remainder (bool, default None): Drop the last incomplete batch when loading. + collate_fn_args (Dict, optional): A `dict` of arguments to be passed to the `collate_fn`. + label_cols (str or List[str], default None): Dataset column(s) to load as labels. + prefetch (bool, default True): Prefetch data. + + Returns: + :class:`tf.data.Dataset` + + """ + if not TF_AVAILABLE: + raise ImportError( + 'The function to_tf_dataset requires TensorFlow to be installed.' + ) + if preprocessors is not None: + return self.to_tf_dataset_with_processors( batch_size, shuffle, preprocessors, drop_remainder=drop_remainder, prefetch=prefetch, label_cols=label_cols, columns=columns) + + if collate_fn is None: + logger.error( 'The `preprocessors` and the `collate_fn` must not both be None.'
+ ) + return None self._hf_ds.reset_format() return self._hf_ds.to_tf_dataset( columns, @@ -123,7 +389,6 @@ class PyDataset: drop_remainder=drop_remainder, collate_fn_args=collate_fn_args, label_cols=label_cols, - dummy_labels=dummy_labels, prefetch=prefetch) def to_hf_dataset(self) -> Dataset: diff --git a/modelscope/pydatasets/utils/__init__.py b/modelscope/pydatasets/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/pydatasets/utils/ms_api.py b/modelscope/pydatasets/utils/ms_api.py new file mode 100644 index 00000000..04052cc4 --- /dev/null +++ b/modelscope/pydatasets/utils/ms_api.py @@ -0,0 +1,66 @@ +import os +from collections import defaultdict +from typing import Optional + +import requests + +from modelscope.pydatasets.config import (DOWNLOADED_DATASETS_PATH, + MS_HUB_ENDPOINT) +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +class MsApi: + + def __init__(self, endpoint=MS_HUB_ENDPOINT): + self.endpoint = endpoint + + def list_datasets(self): + path = f'{self.endpoint}/api/v1/datasets' + headers = None + params = {} + r = requests.get(path, params=params, headers=headers) + r.raise_for_status() + dataset_list = r.json()['Data'] + return [x['Name'] for x in dataset_list] + + def fetch_dataset_scripts(self, + dataset_name: str, + version: Optional[str] = 'master', + force_download=False): + datahub_url = f'{self.endpoint}/api/v1/datasets?Query={dataset_name}' + r = requests.get(datahub_url) + r.raise_for_status() + dataset_list = r.json()['Data'] + if len(dataset_list) == 0: + return None + dataset_id = dataset_list[0]['Id'] + version = version or 'master' + datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={version}' + r = requests.get(datahub_url) + r.raise_for_status() + file_list = r.json()['Data']['Files'] + cache_dir = os.path.join(DOWNLOADED_DATASETS_PATH, dataset_name, + version) + os.makedirs(cache_dir, exist_ok=True) + local_paths = defaultdict(list) + for file_info in file_list: + file_path = file_info['Path'] + if file_path.endswith('.py'): + datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/files?' \ + f'Revision={version}&Path={file_path}' + r = requests.get(datahub_url) + r.raise_for_status() + content = r.json()['Data']['Content'] + local_path = os.path.join(cache_dir, file_path) + if os.path.exists(local_path) and not force_download: + logger.warning( + f"Reusing dataset {dataset_name}'s python file ({local_path})" + ) + local_paths['py'].append(local_path) + continue + with open(local_path, 'w') as f: + f.writelines(content) + local_paths['py'].append(local_path) + return local_paths diff --git a/modelscope/utils/test_utils.py b/modelscope/utils/test_utils.py index c8ea0442..95e63dba 100644 --- a/modelscope/utils/test_utils.py +++ b/modelscope/utils/test_utils.py @@ -2,6 +2,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import os +import unittest + +from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE TEST_LEVEL = 2 TEST_LEVEL_STR = 'TEST_LEVEL' @@ -15,6 +18,18 @@ def test_level(): return TEST_LEVEL +def require_tf(test_case): + if not TF_AVAILABLE: + test_case = unittest.skip('test requires TensorFlow')(test_case) + return test_case + + +def require_torch(test_case): + if not TORCH_AVAILABLE: + test_case = unittest.skip('test requires PyTorch')(test_case) + return test_case + + def set_test_level(level: int): global TEST_LEVEL TEST_LEVEL = level diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index 6e102d00..e557ba86 100644 --- a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -66,6 +66,17 @@ class ImageMattingTest(unittest.TestCase): cv2.imwrite('result.png', result['output_png']) print(f'Output written to {osp.abspath("result.png")}') + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_modelscope_dataset(self): + dataset = PyDataset.load('beans', split='train', target='image') + img_matting = pipeline(Tasks.image_matting, model=self.model_id) + result = img_matting(dataset) + for i in range(10): + cv2.imwrite(f'result_{i}.png', next(result)['output_png']) + print( + f'Output written to dir: {osp.dirname(osp.abspath("result_0.png"))}' + ) + if __name__ == '__main__': unittest.main() diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py index 01fdd29b..bb24fece 100644 --- a/tests/pipelines/test_text_classification.py +++ b/tests/pipelines/test_text_classification.py @@ -86,7 +86,11 @@ class SequenceClassificationTest(unittest.TestCase): task=Tasks.text_classification, model=self.model_id) result = text_classification( PyDataset.load( - 'glue', name='sst2', target='sentence', hub=Hubs.huggingface)) + 'glue', + subset_name='sst2', + split='train', + target='sentence', + hub=Hubs.huggingface)) self.printDataset(result) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @@ -94,7 +98,11 @@ class SequenceClassificationTest(unittest.TestCase): text_classification = pipeline(task=Tasks.text_classification) result = text_classification( PyDataset.load( - 'glue', name='sst2', target='sentence', hub=Hubs.huggingface)) + 'glue', + subset_name='sst2', + split='train', + target='sentence', + hub=Hubs.huggingface)) self.printDataset(result) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @@ -105,9 +113,21 @@ class SequenceClassificationTest(unittest.TestCase): text_classification = pipeline( Tasks.text_classification, model=model, preprocessor=preprocessor) # loaded from huggingface dataset - # TODO: rename parameter as dataset_name and subset_name dataset = PyDataset.load( - 'glue', name='sst2', target='sentence', hub=Hubs.huggingface) + 'glue', + subset_name='sst2', + split='train', + target='sentence', + hub=Hubs.huggingface) + result = text_classification(dataset) + self.printDataset(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_modelscope_dataset(self): + text_classification = pipeline(task=Tasks.text_classification) + # loaded from modelscope dataset + dataset = PyDataset.load( + 'squad', split='train', target='context', hub=Hubs.modelscope) result = text_classification(dataset) self.printDataset(result) diff --git a/tests/pydatasets/test_py_dataset.py b/tests/pydatasets/test_py_dataset.py index 7accd814..4ad767fa 100644 --- 
a/tests/pydatasets/test_py_dataset.py +++ b/tests/pydatasets/test_py_dataset.py @@ -2,42 +2,111 @@ import unittest import datasets as hfdata +from modelscope.models import Model +from modelscope.preprocessors import SequenceClassificationPreprocessor +from modelscope.preprocessors.base import Preprocessor from modelscope.pydatasets import PyDataset +from modelscope.utils.constant import Hubs +from modelscope.utils.test_utils import require_tf, require_torch, test_level -class PyDatasetTest(unittest.TestCase): +class ImgPreprocessor(Preprocessor): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.path_field = kwargs.pop('image_path', 'image_path') + self.width = kwargs.pop('width', 'width') + self.height = kwargs.pop('height', 'height') - def setUp(self): - # ds1 initialized from in memory json - self.json_data = { - 'dummy': [{ - 'a': i, - 'x': i * 10, - 'c': i * 100 - } for i in range(1, 11)] + def __call__(self, data): + import cv2 + image_path = data.get(self.path_field) + if not image_path: + return None + img = cv2.imread(image_path) + # cv2.resize expects dsize as (width, height) + return { + 'image': + cv2.resize(img, + (data.get(self.width, 128), data.get(self.height, 128))) } - hfds1 = hfdata.Dataset.from_dict(self.json_data) - self.ds1 = PyDataset.from_hf_dataset(hfds1) - # ds2 initialized from hg hub - hfds2 = hfdata.load_dataset( - 'glue', 'mrpc', revision='2.0.0', split='train') - self.ds2 = PyDataset.from_hf_dataset(hfds2) - def tearDown(self): - pass +class PyDatasetTest(unittest.TestCase): + + def test_ds_basic(self): + ms_ds_full = PyDataset.load('squad') + ms_ds_full_hf = hfdata.load_dataset('squad') + ms_ds_train = PyDataset.load('squad', split='train') + ms_ds_train_hf = hfdata.load_dataset('squad', split='train') + ms_image_train = PyDataset.from_hf_dataset( + hfdata.load_dataset('beans', split='train')) + self.assertEqual(ms_ds_full['train'][0], ms_ds_full_hf['train'][0]) + self.assertEqual(ms_ds_full['validation'][0], + ms_ds_full_hf['validation'][0]) + self.assertEqual(ms_ds_train[0], ms_ds_train_hf[0]) + print(next(iter(ms_ds_full['train']))) + print(next(iter(ms_ds_train))) + print(next(iter(ms_image_train))) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @require_torch + def test_to_torch_dataset_text(self): + model_id = 'damo/bert-base-sst2' + nlp_model = Model.from_pretrained(model_id) + preprocessor = SequenceClassificationPreprocessor( + nlp_model.model_dir, + first_sequence='context', + second_sequence=None) + ms_ds_train = PyDataset.load('squad', split='train') + pt_dataset = ms_ds_train.to_torch_dataset(preprocessors=preprocessor) + import torch + dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5) + print(next(iter(dataloader))) - def test_to_hf_dataset(self): - hfds = self.ds1.to_hf_dataset() - hfds1 = hfdata.Dataset.from_dict(self.json_data) - self.assertEqual(hfds.data, hfds1.data) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @require_tf + def test_to_tf_dataset_text(self): + import tensorflow as tf + tf.compat.v1.enable_eager_execution() + model_id = 'damo/bert-base-sst2' + nlp_model = Model.from_pretrained(model_id) + preprocessor = SequenceClassificationPreprocessor( + nlp_model.model_dir, + first_sequence='context', + second_sequence=None) + ms_ds_train = PyDataset.load('squad', split='train') + tf_dataset = ms_ds_train.to_tf_dataset( + batch_size=5, + shuffle=True, + preprocessors=preprocessor, + drop_remainder=True) + print(next(iter(tf_dataset))) - # simple map function - hfds =
hfds.map(lambda e: {'new_feature': e['dummy']['a']}) - self.assertEqual(len(hfds['new_feature']), 10) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @require_torch + def test_to_torch_dataset_img(self): + ms_image_train = PyDataset.from_hf_dataset( + hfdata.load_dataset('beans', split='train')) + pt_dataset = ms_image_train.to_torch_dataset( + preprocessors=ImgPreprocessor( + image_path='image_file_path', label='labels')) + import torch + dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5) + print(next(iter(dataloader))) - hfds2 = self.ds2.to_hf_dataset() - self.assertTrue(hfds2[0]['sentence1'].startswith('Amrozi')) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @require_tf + def test_to_tf_dataset_img(self): + import tensorflow as tf + tf.compat.v1.enable_eager_execution() + ms_image_train = PyDataset.load('beans', split='train') + tf_dataset = ms_image_train.to_tf_dataset( + batch_size=5, + shuffle=True, + preprocessors=ImgPreprocessor(image_path='image_file_path'), + drop_remainder=True, + label_cols='labels') + print(next(iter(tf_dataset))) if __name__ == '__main__':
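
Usage sketch: a minimal, illustrative end-to-end example of the API introduced by this patch, mirroring the new test_to_torch_dataset_text; the 'squad' dataset script and the 'damo/bert-base-sst2' model are the fixtures the new tests rely on and are assumed to be reachable from the configured hubs.

    from modelscope.models import Model
    from modelscope.preprocessors import SequenceClassificationPreprocessor
    from modelscope.pydatasets import PyDataset
    from modelscope.utils.constant import Hubs

    # Fetch the dataset script from the ModelScope datahub and load the train split.
    ds = PyDataset.load('squad', split='train', hub=Hubs.modelscope)

    # Map the raw 'context' field to model inputs with an existing preprocessor.
    model = Model.from_pretrained('damo/bert-base-sst2')
    preprocessor = SequenceClassificationPreprocessor(
        model.model_dir, first_sequence='context', second_sequence=None)

    # Wrap the dataset as a torch IterableDataset; non-numeric fields are dropped.
    pt_dataset = ds.to_torch_dataset(preprocessors=preprocessor)

    import torch
    dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
    print(next(iter(dataloader)))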