From 743e8769817fc802c2a9c72d7429c2bdd2250001 Mon Sep 17 00:00:00 2001
From: "feiwu.yfw"
Date: Fri, 29 Jul 2022 12:22:48 +0800
Subject: [PATCH] [to #43660556] msdataset dataset loading
 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9552632
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* load csv dataset from modelscope
---
 modelscope/hub/api.py                         |  55 +++++++--
 modelscope/msdatasets/ms_dataset.py           |  99 ++++++++++++---
 modelscope/msdatasets/utils/__init__.py       |   0
 .../msdatasets/utils/dataset_builder.py       | 113 ++++++++++++++++++
 modelscope/msdatasets/utils/dataset_utils.py  | 113 ++++++++++++++++++
 modelscope/msdatasets/utils/download_utils.py |  41 +++++++
 modelscope/msdatasets/utils/oss_utils.py      |  37 ++++++
 modelscope/utils/constant.py                  |  17 +++
 requirements/runtime.txt                      |   1 +
 tests/msdatasets/test_ms_dataset.py           |   8 +-
 .../trainers/test_text_generation_trainer.py  |   4 +-
 tests/trainers/test_trainer_with_nlp.py       |   1 +
 12 files changed, 458 insertions(+), 31 deletions(-)
 create mode 100644 modelscope/msdatasets/utils/__init__.py
 create mode 100644 modelscope/msdatasets/utils/dataset_builder.py
 create mode 100644 modelscope/msdatasets/utils/dataset_utils.py
 create mode 100644 modelscope/msdatasets/utils/download_utils.py
 create mode 100644 modelscope/msdatasets/utils/oss_utils.py

diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py
index 1998858c..824fe58e 100644
--- a/modelscope/hub/api.py
+++ b/modelscope/hub/api.py
@@ -12,7 +12,9 @@ import requests
 from modelscope.msdatasets.config import (DOWNLOADED_DATASETS_PATH,
                                           HUB_DATASET_ENDPOINT)
 from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
-                                       DEFAULT_MODEL_REVISION, DownloadMode)
+                                       DEFAULT_MODEL_REVISION,
+                                       DatasetFormations, DatasetMetaFormats,
+                                       DownloadMode)
 from modelscope.utils.logger import get_logger
 from .errors import (InvalidParameter, NotExistError, RequestError,
                      datahub_raise_on_error, handle_http_response, is_ok,
@@ -301,8 +303,8 @@ class HubApi:
                 f'Dataset from Hubs.modelscope should have a valid "namespace", but get {namespace}'
             )
         revision = revision or DEFAULT_DATASET_REVISION
-        cache_dir = os.path.join(DOWNLOADED_DATASETS_PATH, dataset_name,
-                                 namespace, revision)
+        cache_dir = os.path.join(DOWNLOADED_DATASETS_PATH, namespace,
+                                 dataset_name, revision)
         download_mode = DownloadMode(download_mode
                                      or DownloadMode.REUSE_DATASET_IF_EXISTS)
         if download_mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists(
@@ -314,6 +316,7 @@ class HubApi:
         resp = r.json()
         datahub_raise_on_error(datahub_url, resp)
         dataset_id = resp['Data']['Id']
+        dataset_type = resp['Data']['Type']
         datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
         r = requests.get(datahub_url)
         resp = r.json()
@@ -326,25 +329,53 @@ class HubApi:
         file_list = file_list['Files']

         local_paths = defaultdict(list)
+        dataset_formation = DatasetFormations(dataset_type)
+        dataset_meta_format = DatasetMetaFormats[dataset_formation]
         for file_info in file_list:
             file_path = file_info['Path']
-            if file_path.endswith('.py'):
-                datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{dataset_id}/repo/files?' \
-                              f'Revision={revision}&Path={file_path}'
+            extension = os.path.splitext(file_path)[-1]
+            if extension in dataset_meta_format:
+                datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
+                              f'Revision={revision}&FilePath={file_path}'
                 r = requests.get(datahub_url)
                 r.raise_for_status()
-                content = r.json()['Data']['Content']
                 local_path = os.path.join(cache_dir, file_path)
                 if os.path.exists(local_path):
                     logger.warning(
                         f"Reusing dataset {dataset_name}'s python file ({local_path})"
                     )
-                    local_paths['py'].append(local_path)
+                    local_paths[extension].append(local_path)
                     continue
-                with open(local_path, 'w') as f:
-                    f.writelines(content)
-                local_paths['py'].append(local_path)
-        return local_paths
+                with open(local_path, 'wb') as f:
+                    f.write(r.content)
+                local_paths[extension].append(local_path)
+
+        return local_paths, dataset_formation, cache_dir
+
+    def get_dataset_file_url(
+            self,
+            file_name: str,
+            dataset_name: str,
+            namespace: str,
+            revision: Optional[str] = DEFAULT_DATASET_REVISION):
+        return f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
+               f'Revision={revision}&FilePath={file_name}'
+
+    def get_dataset_access_config(
+            self,
+            dataset_name: str,
+            namespace: str,
+            revision: Optional[str] = DEFAULT_DATASET_REVISION):
+        datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
+                      f'ststoken?Revision={revision}'
+        return self.datahub_remote_call(datahub_url)
+
+    @staticmethod
+    def datahub_remote_call(url):
+        r = requests.get(url)
+        resp = r.json()
+        datahub_raise_on_error(url, resp)
+        return resp['Data']


 class ModelScopeConfig:
diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py
index efe624cb..8174d054 100644
--- a/modelscope/msdatasets/ms_dataset.py
+++ b/modelscope/msdatasets/ms_dataset.py
@@ -2,17 +2,24 @@ import os
 from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
                     Sequence, Union)

+import json
 import numpy as np
 from datasets import Dataset, DatasetDict
 from datasets import load_dataset as hf_load_dataset
 from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE
 from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
+from datasets.utils.download_manager import DownloadConfig
 from datasets.utils.file_utils import (is_relative_path,
                                        relative_to_absolute_path)

 from modelscope.msdatasets.config import MS_DATASETS_CACHE
-from modelscope.utils.constant import DownloadMode, Hubs
+from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
+                                       DatasetFormations, DownloadMode, Hubs)
 from modelscope.utils.logger import get_logger
+from .utils.dataset_utils import (get_dataset_files,
+                                  get_target_dataset_structure,
+                                  load_dataset_builder)
+from .utils.download_utils import DatasetDownloadManager

 logger = get_logger()

@@ -80,7 +87,7 @@ class MsDataset:
              dataset_name: Union[str, list],
              namespace: Optional[str] = None,
              target: Optional[str] = None,
-             version: Optional[str] = None,
+             version: Optional[str] = DEFAULT_DATASET_REVISION,
              hub: Optional[Hubs] = Hubs.modelscope,
              subset_name: Optional[str] = None,
              split: Optional[str] = None,
@@ -95,7 +102,7 @@ class MsDataset:

         Args:
             dataset_name (str): Path or name of the dataset.
-            namespace(str, optional): Namespace of the dataset. It should not be None, if you load a remote dataset
+            namespace(str, optional): Namespace of the dataset. It should not be None if you load a remote dataset
             from Hubs.modelscope,
             target (str, optional): Name of the column to output.
             version (str, optional): Version of the dataset script to load:
@@ -140,7 +147,7 @@ class MsDataset:
                          dataset_name: Union[str, list],
                          namespace: Optional[str] = None,
                          target: Optional[str] = None,
-                         version: Optional[str] = None,
+                         version: Optional[str] = DEFAULT_DATASET_REVISION,
                          subset_name: Optional[str] = None,
                          split: Optional[str] = None,
                          data_dir: Optional[str] = None,
@@ -150,25 +157,25 @@ class MsDataset:
                          download_mode: Optional[DownloadMode] = None
                          ) -> Union[dict, 'MsDataset']:
         if isinstance(dataset_name, str):
-            use_hf = False
+            dataset_formation = DatasetFormations.native
             if dataset_name in _PACKAGED_DATASETS_MODULES or os.path.isdir(dataset_name) or \
                     (os.path.isfile(dataset_name) and dataset_name.endswith('.py')):
-                use_hf = True
+                dataset_formation = DatasetFormations.hf_compatible
             elif is_relative_path(dataset_name) and dataset_name.count(
                     '/') == 0:
                 from modelscope.hub.api import HubApi
                 api = HubApi()
-                dataset_scripts = api.fetch_dataset_scripts(
+                dataset_scripts, dataset_formation, download_dir = api.fetch_dataset_scripts(
                     dataset_name, namespace, download_mode, version)
-                if 'py' in dataset_scripts:  # dataset copied from hf datasets
-                    dataset_name = dataset_scripts['py'][0]
-                    use_hf = True
+                # dataset organized to be compatible with hf format
+                if dataset_formation == DatasetFormations.hf_compatible:
+                    dataset_name = dataset_scripts['.py'][0]
             else:
                 raise FileNotFoundError(
                     f"Couldn't find a dataset script at {relative_to_absolute_path(dataset_name)} "
                     f'or any data file in the same directory.')

-            if use_hf:
+            if dataset_formation == DatasetFormations.hf_compatible:
                 dataset = hf_load_dataset(
                     dataset_name,
                     name=subset_name,
@@ -179,10 +186,16 @@ class MsDataset:
                     cache_dir=MS_DATASETS_CACHE,
                     download_mode=download_mode.value)
             else:
-                # TODO load from ms datahub
-                raise NotImplementedError(
-                    f'Dataset {dataset_name} load from modelscope datahub to be implemented in '
-                    f'the future')
+                dataset = MsDataset._load_from_ms(
+                    dataset_name,
+                    dataset_scripts,
+                    download_dir,
+                    namespace=namespace,
+                    version=version,
+                    subset_name=subset_name,
+                    split=split,
+                    download_mode=download_mode,
+                )
         elif isinstance(dataset_name, list):
             if target is None:
                 target = 'target'
@@ -192,6 +205,62 @@ class MsDataset:
                              f' {type(dataset_name)}')
         return MsDataset.from_hf_dataset(dataset, target=target)

+    @staticmethod
+    def _load_from_ms(
+        dataset_name: str,
+        dataset_files: dict,
+        download_dir: str,
+        namespace: Optional[str] = None,
+        version: Optional[str] = DEFAULT_DATASET_REVISION,
+        subset_name: Optional[str] = None,
+        split: Optional[str] = None,
+        download_mode: Optional[DownloadMode] = None,
+    ) -> Union[Dataset, DatasetDict]:
+        for json_path in dataset_files['.json']:
+            if json_path.endswith(f'{dataset_name}.json'):
+                with open(json_path, encoding='utf-8') as dataset_json_file:
+                    dataset_json = json.load(dataset_json_file)
+                break
+        target_subset_name, target_dataset_structure = get_target_dataset_structure(
+            dataset_json, subset_name, split)
+        meta_map, file_map = get_dataset_files(target_dataset_structure,
+                                               dataset_name, namespace,
+                                               version)
+
+        builder = load_dataset_builder(
+            dataset_name,
+            subset_name,
+            namespace,
+            meta_data_files=meta_map,
+            zip_data_files=file_map,
+            cache_dir=MS_DATASETS_CACHE,
+            version=version,
+            split=list(target_dataset_structure.keys()))
+
+        download_config = DownloadConfig(
+            cache_dir=download_dir,
+            force_download=bool(
+                download_mode == DownloadMode.FORCE_REDOWNLOAD),
+            force_extract=bool(download_mode == DownloadMode.FORCE_REDOWNLOAD),
+            use_etag=False,
+        )
+
+        dl_manager = DatasetDownloadManager(
+            dataset_name=dataset_name,
+            namespace=namespace,
+            version=version,
+            download_config=download_config,
+            data_dir=download_dir,
+        )
+        builder.download_and_prepare(
+            download_config=download_config,
+            dl_manager=dl_manager,
+            download_mode=download_mode.value,
+            try_from_hf_gcs=False)
+
+        ds = builder.as_dataset()
+        return ds
+
     def to_torch_dataset_with_processors(
             self,
             preprocessors: Union[Callable, List[Callable]],
diff --git a/modelscope/msdatasets/utils/__init__.py b/modelscope/msdatasets/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/msdatasets/utils/dataset_builder.py b/modelscope/msdatasets/utils/dataset_builder.py
new file mode 100644
index 00000000..2b4bad07
--- /dev/null
+++ b/modelscope/msdatasets/utils/dataset_builder.py
@@ -0,0 +1,113 @@
+import os
+from typing import Mapping, Sequence, Union
+
+import datasets
+import pandas as pd
+import pyarrow as pa
+from datasets.info import DatasetInfo
+from datasets.packaged_modules import csv
+from datasets.utils.filelock import FileLock
+
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+class MsCsvDatasetBuilder(csv.Csv):
+
+    def __init__(
+        self,
+        dataset_name: str,
+        cache_dir: str,
+        namespace: str,
+        subset_name: str,
+        hash: str,
+        meta_data_files: Mapping[str, Union[str, Sequence[str]]],
+        zip_data_files: Mapping[str, Union[str, Sequence[str]]] = None,
+        **config_kwargs,
+    ):
+        super().__init__(
+            cache_dir=cache_dir,
+            name=subset_name,
+            hash=hash,
+            namespace=namespace,
+            data_files=meta_data_files,
+            **config_kwargs)
+
+        self.name = dataset_name
+        self.info.builder_name = self.name
+        self._cache_dir = self._build_cache_dir()
+        lock_path = os.path.join(
+            self._cache_dir_root,
+            self._cache_dir.replace(os.sep, '_') + '.lock')
+        with FileLock(lock_path):
+            # check if data exist
+            if os.path.exists(self._cache_dir):
+                if len(os.listdir(self._cache_dir)) > 0:
+                    logger.info(
+                        f'Overwrite dataset info from restored data version, cache_dir is {self._cache_dir}'
+                    )
+                    self.info = DatasetInfo.from_directory(self._cache_dir)
+                # dir exists but no data, remove the empty dir as data aren't available anymore
+                else:
+                    logger.warning(
+                        f'Old caching folder {self._cache_dir} for dataset {self.name} exists '
+                        f'but no data were found. Removing it. ')
+                    os.rmdir(self._cache_dir)
+        self.zip_data_files = zip_data_files
+
+    def _build_cache_dir(self):
+        builder_data_dir = os.path.join(
+            self._cache_dir_root,
+            self._relative_data_dir(with_version=False, with_hash=True))
+
+        return builder_data_dir
+
+    def _split_generators(self, dl_manager):
+        if not self.config.data_files:
+            raise ValueError(
+                'At least one data file must be specified, but got none.')
+        data_files = dl_manager.download_and_extract(self.config.data_files)
+        zip_data_files = dl_manager.download_and_extract(self.zip_data_files)
+        splits = []
+        for split_name, files in data_files.items():
+            if isinstance(files, str):
+                files = [files]
+            splits.append(
+                datasets.SplitGenerator(
+                    name=split_name,
+                    gen_kwargs={
+                        'files': dl_manager.iter_files(files),
+                        'base_dir': zip_data_files.get(split_name)
+                    }))
+        return splits
+
+    def _generate_tables(self, files, base_dir):
+        schema = pa.schema(self.config.features.type
+                           ) if self.config.features is not None else None
+        dtype = {
+            name: dtype.to_pandas_dtype()
+            for name, dtype in zip(schema.names, schema.types)
+        } if schema else None
+        for file_idx, file in enumerate(files):
+            csv_file_reader = pd.read_csv(
+                file,
+                iterator=True,
+                dtype=dtype,
+                **self.config.read_csv_kwargs)
+            transform_fields = []
+            for field_name in csv_file_reader._engine.names:
+                if field_name.endswith(':FILE'):
+                    transform_fields.append(field_name)
+            try:
+                for batch_idx, df in enumerate(csv_file_reader):
+                    for field_name in transform_fields:
+                        if base_dir:
+                            df[field_name] = df[field_name].apply(
+                                lambda x: os.path.join(base_dir, x))
+                    pa_table = pa.Table.from_pandas(df, schema=schema)
+                    yield (file_idx, batch_idx), pa_table
+            except ValueError as e:
+                logger.error(
+                    f"Failed to read file '{file}' with error {type(e)}: {e}")
+                raise
diff --git a/modelscope/msdatasets/utils/dataset_utils.py b/modelscope/msdatasets/utils/dataset_utils.py
new file mode 100644
index 00000000..ff7cd8b1
--- /dev/null
+++ b/modelscope/msdatasets/utils/dataset_utils.py
@@ -0,0 +1,113 @@
+import os
+from collections import defaultdict
+from typing import Mapping, Optional, Sequence, Union
+
+from datasets.builder import DatasetBuilder
+
+from modelscope.utils.constant import DEFAULT_DATASET_REVISION
+from modelscope.utils.logger import get_logger
+from .dataset_builder import MsCsvDatasetBuilder
+
+logger = get_logger()
+
+
+def get_target_dataset_structure(dataset_structure: dict,
+                                 subset_name: Optional[str] = None,
+                                 split: Optional[str] = None):
+    """
+    Args:
+        dataset_structure (dict): Dataset Structure, like
+         {
+            "default":{
+                "train":{
+                    "meta":"my_train.csv",
+                    "file":"pictures.zip"
+                }
+            },
+            "subsetA":{
+                "test":{
+                    "meta":"mytest.csv",
+                    "file":"pictures.zip"
+                }
+            }
+         }
+        subset_name (str, optional): Defining the subset_name of the dataset.
+        split (str, optional): Which split of the data to load.
+    Returns:
+        target_subset_name (str): Name of the chosen subset.
+        target_dataset_structure (dict): Structure of the chosen split(s), like
+         {
+            "test":{
+                "meta":"mytest.csv",
+                "file":"pictures.zip"
+            }
+         }
+    """
+    # verify dataset subset
+    if (subset_name and subset_name not in dataset_structure) or (
+            not subset_name and len(dataset_structure.keys()) > 1):
+        raise ValueError(
+            f'subset_name {subset_name} not found. Available: {dataset_structure.keys()}'
+        )
+    target_subset_name = subset_name
+    if not subset_name:
+        target_subset_name = next(iter(dataset_structure.keys()))
+        logger.info(
+            f'No subset_name specified, defaulting to {target_subset_name}'
+        )
+    # verify dataset split
+    target_dataset_structure = dataset_structure[target_subset_name]
+    if split and split not in target_dataset_structure:
+        raise ValueError(
+            f'split {split} not found. Available: {target_dataset_structure.keys()}'
+        )
+    if split:
+        target_dataset_structure = {split: target_dataset_structure[split]}
+    return target_subset_name, target_dataset_structure
+
+
+def get_dataset_files(subset_split_into: dict,
+                      dataset_name: str,
+                      namespace: str,
+                      revision: Optional[str] = DEFAULT_DATASET_REVISION):
+    """
+    Return:
+        meta_map: Structure of meta files (.csv), the meta file name will be replaced by url, like
+        {
+            "test": "https://xxx/mytest.csv"
+        }
+        file_map: Structure of data files (.zip), like
+        {
+            "test": "pictures.zip"
+        }
+    """
+    meta_map = defaultdict(dict)
+    file_map = defaultdict(dict)
+    from modelscope.hub.api import HubApi
+    modelscope_api = HubApi()
+    for split, info in subset_split_into.items():
+        meta_map[split] = modelscope_api.get_dataset_file_url(
+            info['meta'], dataset_name, namespace, revision)
+        if info.get('file'):
+            file_map[split] = info['file']
+    return meta_map, file_map
+
+
+def load_dataset_builder(dataset_name: str, subset_name: str, namespace: str,
+                         meta_data_files: Mapping[str, Union[str,
+                                                             Sequence[str]]],
+                         zip_data_files: Mapping[str, Union[str,
+                                                            Sequence[str]]],
+                         cache_dir: str, version: Optional[Union[str]],
+                         split: Sequence[str]) -> DatasetBuilder:
+    sub_dir = os.path.join(version, '_'.join(split))
+    builder_instance = MsCsvDatasetBuilder(
+        dataset_name=dataset_name,
+        namespace=namespace,
+        cache_dir=cache_dir,
+        subset_name=subset_name,
+        meta_data_files=meta_data_files,
+        zip_data_files=zip_data_files,
+        hash=sub_dir)
+
+    return builder_instance
diff --git a/modelscope/msdatasets/utils/download_utils.py b/modelscope/msdatasets/utils/download_utils.py
new file mode 100644
index 00000000..bc637f0e
--- /dev/null
+++ b/modelscope/msdatasets/utils/download_utils.py
@@ -0,0 +1,41 @@
+from typing import Optional
+
+from datasets.utils.download_manager import DownloadConfig, DownloadManager
+from datasets.utils.file_utils import cached_path, is_relative_path
+
+from .oss_utils import OssUtilities
+
+
+class DatasetDownloadManager(DownloadManager):
+
+    def __init__(
+        self,
+        dataset_name: str,
+        namespace: str,
+        version: str,
+        data_dir: Optional[str] = None,
+        download_config: Optional[DownloadConfig] = None,
+        base_path: Optional[str] = None,
+        record_checksums=True,
+    ):
+        super().__init__(dataset_name, data_dir, download_config, base_path,
+                         record_checksums)
+        self._namespace = namespace
+        self._version = version
+        from modelscope.hub.api import HubApi
+        api = HubApi()
+        oss_config = api.get_dataset_access_config(self._dataset_name,
+                                                   self._namespace,
+                                                   self._version)
+        self.oss_utilities = OssUtilities(oss_config)
+
+    def _download(self, url_or_filename: str,
+                  download_config: DownloadConfig) -> str:
+        url_or_filename = str(url_or_filename)
+        if is_relative_path(url_or_filename):
+            # fetch oss files
+            return self.oss_utilities.download(url_or_filename,
+                                               self.download_config.cache_dir)
+        else:
+            return cached_path(
+                url_or_filename, download_config=download_config)
diff --git a/modelscope/msdatasets/utils/oss_utils.py b/modelscope/msdatasets/utils/oss_utils.py
new file mode 100644
index 00000000..83cfc7dd
--- /dev/null
+++ b/modelscope/msdatasets/utils/oss_utils.py
@@ -0,0 +1,37 @@
+from __future__ import print_function
+import os
+import sys
+
+import oss2
+from datasets.utils.file_utils import hash_url_to_filename
+
+
+class OssUtilities:
+
+    def __init__(self, oss_config):
+        self.key = oss_config['AccessId']
+        self.secret = oss_config['AccessSecret']
+        self.token = oss_config['SecurityToken']
+        self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com"
+        self.bucket_name = oss_config['Bucket']
+        auth = oss2.StsAuth(self.key, self.secret, self.token)
+        self.bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name)
+        self.oss_dir = oss_config['Dir']
+        self.oss_backup_dir = oss_config['BackupDir']
+
+    def download(self, oss_file_name, cache_dir):
+        candidate_key = os.path.join(self.oss_dir, oss_file_name)
+        candidate_key_backup = os.path.join(self.oss_backup_dir, oss_file_name)
+        file_oss_key = candidate_key if self.bucket.object_exists(
+            candidate_key) else candidate_key_backup
+        filename = hash_url_to_filename(file_oss_key, etag=None)
+        local_path = os.path.join(cache_dir, filename)
+
+        def percentage(consumed_bytes, total_bytes):
+            if total_bytes:
+                rate = int(100 * (float(consumed_bytes) / float(total_bytes)))
+                print('\r{0}% '.format(rate), end='', flush=True)
+
+        self.bucket.get_object_to_file(
+            file_oss_key, local_path, progress_callback=percentage)
+        return local_path
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 1eab664c..6bac48ee 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -152,6 +152,23 @@ class DownloadMode(enum.Enum):
     FORCE_REDOWNLOAD = 'force_redownload'


+class DatasetFormations(enum.Enum):
+    """ How a dataset is organized and interpreted
+    """
+    # formation that is compatible with the official huggingface dataset, which
+    # organizes the whole dataset into one single (zip) file.
+    hf_compatible = 1
+    # native modelscope formation that supports, among other things,
+    # multiple files in a dataset
+    native = 2
+
+
+DatasetMetaFormats = {
+    DatasetFormations.native: ['.json'],
+    DatasetFormations.hf_compatible: ['.py'],
+}
+
+
 class ModelFile(object):
     CONFIGURATION = 'configuration.json'
     README = 'README.md'
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 0542dc92..fbf33854 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -6,6 +6,7 @@ filelock>=3.3.0
 gast>=0.2.2
 numpy
 opencv-python
+oss2
 Pillow>=6.2.0
 pyyaml
 requests
diff --git a/tests/msdatasets/test_ms_dataset.py b/tests/msdatasets/test_ms_dataset.py
index 08a05e9c..0894ce3d 100644
--- a/tests/msdatasets/test_ms_dataset.py
+++ b/tests/msdatasets/test_ms_dataset.py
@@ -1,7 +1,5 @@
 import unittest

-import datasets as hfdata
-
 from modelscope.models import Model
 from modelscope.msdatasets import MsDataset
 from modelscope.preprocessors import SequenceClassificationPreprocessor
@@ -32,6 +30,12 @@ class ImgPreprocessor(Preprocessor):

 class MsDatasetTest(unittest.TestCase):

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_ms_csv_basic(self):
+        ms_ds_train = MsDataset.load(
+            'afqmc_small', namespace='userxiaoming', split='train')
+        print(next(iter(ms_ds_train)))
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_ds_basic(self):
         ms_ds_full = MsDataset.load(
diff --git a/tests/trainers/test_text_generation_trainer.py b/tests/trainers/test_text_generation_trainer.py
index 28f08c97..7c24bc0a 100644
--- a/tests/trainers/test_text_generation_trainer.py
+++ b/tests/trainers/test_text_generation_trainer.py
@@ -21,10 +21,10 @@ class TestTextGenerationTrainer(unittest.TestCase):
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)

-        from datasets import Dataset
-
         self.model_id = 'damo/nlp_palm2.0_text-generation_english-base'

+        # todo: Replace the scripts below with MsDataset.load when the formal dataset service is ready
+        from datasets import Dataset
         dataset_dict = {
             'src_txt': [
                 'This is test sentence1-1', 'This is test sentence2-1',
diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py
index 93d13065..a28bc9e9 100644
--- a/tests/trainers/test_trainer_with_nlp.py
+++ b/tests/trainers/test_trainer_with_nlp.py
@@ -23,6 +23,7 @@ class TestTrainerWithNlp(unittest.TestCase):
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)

+        # todo: Replace the scripts below with MsDataset.load when the formal dataset service is ready
         from datasets import Dataset
         dataset_dict = {
             'sentence1': [
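
Usage sketch: a minimal example of the CSV loading path added in this patch,
mirroring the new test_ms_csv_basic test above (the dataset name 'afqmc_small'
and namespace 'userxiaoming' are taken from that test; everything else relies
on the defaults introduced here):

    from modelscope.msdatasets import MsDataset

    # For a plain dataset name plus namespace, MsDataset.load goes through
    # HubApi.fetch_dataset_scripts; a native-formation dataset is described by
    # <dataset_name>.json, each split's CSV meta file is resolved through
    # HubApi.get_dataset_file_url, and zip-packed data files are fetched from
    # OSS by DatasetDownloadManager/OssUtilities before MsCsvDatasetBuilder
    # materializes the split.
    ms_ds_train = MsDataset.load(
        'afqmc_small', namespace='userxiaoming', split='train')
    print(next(iter(ms_ds_train)))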