
[to #43660556] MsDataset dataset loading

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9552632

* load CSV datasets from the ModelScope hub (see the usage sketch after the changed-file list below)
Branch: master
feiwu.yfw committed 3 years ago
commit 743e876981
12 changed files with 458 additions and 31 deletions
  1. modelscope/hub/api.py  (+43, -12)
  2. modelscope/msdatasets/ms_dataset.py  (+84, -15)
  3. modelscope/msdatasets/utils/__init__.py  (+0, -0)
  4. modelscope/msdatasets/utils/dataset_builder.py  (+113, -0)
  5. modelscope/msdatasets/utils/dataset_utils.py  (+113, -0)
  6. modelscope/msdatasets/utils/download_utils.py  (+41, -0)
  7. modelscope/msdatasets/utils/oss_utils.py  (+37, -0)
  8. modelscope/utils/constant.py  (+17, -0)
  9. requirements/runtime.txt  (+1, -0)
  10. tests/msdatasets/test_ms_dataset.py  (+6, -2)
  11. tests/trainers/test_text_generation_trainer.py  (+2, -2)
  12. tests/trainers/test_trainer_with_nlp.py  (+1, -0)
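For orientation, a minimal usage sketch of the new loading path, mirroring the test case added in tests/msdatasets/test_ms_dataset.py (the dataset name and namespace are the ones used there):

    from modelscope.msdatasets import MsDataset

    # Load the 'train' split of a natively formatted (CSV meta) dataset
    # hosted on the ModelScope hub.
    ms_ds_train = MsDataset.load(
        'afqmc_small', namespace='userxiaoming', split='train')
    print(next(iter(ms_ds_train)))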

modelscope/hub/api.py  (+43, -12)

@@ -12,7 +12,9 @@ import requests
from modelscope.msdatasets.config import (DOWNLOADED_DATASETS_PATH,
HUB_DATASET_ENDPOINT)
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION, DownloadMode)
DEFAULT_MODEL_REVISION,
DatasetFormations, DatasetMetaFormats,
DownloadMode)
from modelscope.utils.logger import get_logger
from .errors import (InvalidParameter, NotExistError, RequestError,
datahub_raise_on_error, handle_http_response, is_ok,
@@ -301,8 +303,8 @@ class HubApi:
f'Dataset from Hubs.modelscope should have a valid "namespace", but get {namespace}'
)
revision = revision or DEFAULT_DATASET_REVISION
cache_dir = os.path.join(DOWNLOADED_DATASETS_PATH, dataset_name,
namespace, revision)
cache_dir = os.path.join(DOWNLOADED_DATASETS_PATH, namespace,
dataset_name, revision)
download_mode = DownloadMode(download_mode
or DownloadMode.REUSE_DATASET_IF_EXISTS)
if download_mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists(
@@ -314,6 +316,7 @@ class HubApi:
resp = r.json()
datahub_raise_on_error(datahub_url, resp)
dataset_id = resp['Data']['Id']
dataset_type = resp['Data']['Type']
datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
r = requests.get(datahub_url)
resp = r.json()
@@ -326,25 +329,53 @@ class HubApi:

file_list = file_list['Files']
local_paths = defaultdict(list)
dataset_formation = DatasetFormations(dataset_type)
dataset_meta_format = DatasetMetaFormats[dataset_formation]
for file_info in file_list:
file_path = file_info['Path']
if file_path.endswith('.py'):
datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{dataset_id}/repo/files?' \
f'Revision={revision}&Path={file_path}'
extension = os.path.splitext(file_path)[-1]
if extension in dataset_meta_format:
datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
f'Revision={revision}&FilePath={file_path}'
r = requests.get(datahub_url)
r.raise_for_status()
content = r.json()['Data']['Content']
local_path = os.path.join(cache_dir, file_path)
if os.path.exists(local_path):
logger.warning(
f"Reusing dataset {dataset_name}'s python file ({local_path})"
)
local_paths['py'].append(local_path)
local_paths[extension].append(local_path)
continue
with open(local_path, 'w') as f:
f.writelines(content)
local_paths['py'].append(local_path)
return local_paths
with open(local_path, 'wb') as f:
f.write(r.content)
local_paths[extension].append(local_path)

return local_paths, dataset_formation, cache_dir

def get_dataset_file_url(
self,
file_name: str,
dataset_name: str,
namespace: str,
revision: Optional[str] = DEFAULT_DATASET_REVISION):
return f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
f'Revision={revision}&FilePath={file_name}'

def get_dataset_access_config(
self,
dataset_name: str,
namespace: str,
revision: Optional[str] = DEFAULT_DATASET_REVISION):
datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
f'ststoken?Revision={revision}'
return self.datahub_remote_call(datahub_url)

@staticmethod
def datahub_remote_call(url):
r = requests.get(url)
resp = r.json()
datahub_raise_on_error(url, resp)
return resp['Data']


class ModelScopeConfig:
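A rough sketch of how the two new HubApi helpers are meant to be used; the dataset, namespace and file names below are placeholders, and the actual endpoint comes from HUB_DATASET_ENDPOINT:

    from modelscope.hub.api import HubApi

    api = HubApi()

    # Compose the direct download URL of a meta file inside a dataset repo
    # (string formatting only, no request is issued here).
    url = api.get_dataset_file_url('mytest.csv', 'my_dataset', 'my_namespace')
    # -> '<dataset_endpoint>/api/v1/datasets/my_namespace/my_dataset/repo'
    #    '?Revision=<DEFAULT_DATASET_REVISION>&FilePath=mytest.csv'

    # Ask the hub for temporary STS credentials (AccessId, AccessSecret,
    # SecurityToken, Bucket, ...) later consumed by OssUtilities.
    oss_config = api.get_dataset_access_config('my_dataset', 'my_namespace')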


modelscope/msdatasets/ms_dataset.py  (+84, -15)

@@ -2,17 +2,24 @@ import os
from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
Sequence, Union)

import json
import numpy as np
from datasets import Dataset, DatasetDict
from datasets import load_dataset as hf_load_dataset
from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE
from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
from datasets.utils.download_manager import DownloadConfig
from datasets.utils.file_utils import (is_relative_path,
relative_to_absolute_path)

from modelscope.msdatasets.config import MS_DATASETS_CACHE
from modelscope.utils.constant import DownloadMode, Hubs
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DatasetFormations, DownloadMode, Hubs)
from modelscope.utils.logger import get_logger
from .utils.dataset_utils import (get_dataset_files,
get_target_dataset_structure,
load_dataset_builder)
from .utils.download_utils import DatasetDownloadManager

logger = get_logger()

@@ -80,7 +87,7 @@ class MsDataset:
dataset_name: Union[str, list],
namespace: Optional[str] = None,
target: Optional[str] = None,
version: Optional[str] = None,
version: Optional[str] = DEFAULT_DATASET_REVISION,
hub: Optional[Hubs] = Hubs.modelscope,
subset_name: Optional[str] = None,
split: Optional[str] = None,
@@ -95,7 +102,7 @@ class MsDataset:
Args:

dataset_name (str): Path or name of the dataset.
namespace(str, optional): Namespace of the dataset. It should not be None, if you load a remote dataset
namespace(str, optional): Namespace of the dataset. It should not be None if you load a remote dataset
from Hubs.modelscope,
target (str, optional): Name of the column to output.
version (str, optional): Version of the dataset script to load:
@@ -140,7 +147,7 @@ class MsDataset:
dataset_name: Union[str, list],
namespace: Optional[str] = None,
target: Optional[str] = None,
version: Optional[str] = None,
version: Optional[str] = DEFAULT_DATASET_REVISION,
subset_name: Optional[str] = None,
split: Optional[str] = None,
data_dir: Optional[str] = None,
@@ -150,25 +157,25 @@ class MsDataset:
download_mode: Optional[DownloadMode] = None
) -> Union[dict, 'MsDataset']:
if isinstance(dataset_name, str):
use_hf = False
dataset_formation = DatasetFormations.native
if dataset_name in _PACKAGED_DATASETS_MODULES or os.path.isdir(dataset_name) or \
(os.path.isfile(dataset_name) and dataset_name.endswith('.py')):
use_hf = True
dataset_formation = DatasetFormations.hf_compatible
elif is_relative_path(dataset_name) and dataset_name.count(
'/') == 0:
from modelscope.hub.api import HubApi
api = HubApi()
dataset_scripts = api.fetch_dataset_scripts(
dataset_scripts, dataset_formation, download_dir = api.fetch_dataset_scripts(
dataset_name, namespace, download_mode, version)
if 'py' in dataset_scripts: # dataset copied from hf datasets
dataset_name = dataset_scripts['py'][0]
use_hf = True
# dataset organized to be compatible with hf format
if dataset_formation == DatasetFormations.hf_compatible:
dataset_name = dataset_scripts['.py'][0]
else:
raise FileNotFoundError(
f"Couldn't find a dataset script at {relative_to_absolute_path(dataset_name)} "
f'or any data file in the same directory.')

if use_hf:
if dataset_formation == DatasetFormations.hf_compatible:
dataset = hf_load_dataset(
dataset_name,
name=subset_name,
@@ -179,10 +186,16 @@ class MsDataset:
cache_dir=MS_DATASETS_CACHE,
download_mode=download_mode.value)
else:
# TODO load from ms datahub
raise NotImplementedError(
f'Dataset {dataset_name} load from modelscope datahub to be implemented in '
f'the future')
dataset = MsDataset._load_from_ms(
dataset_name,
dataset_scripts,
download_dir,
namespace=namespace,
version=version,
subset_name=subset_name,
split=split,
download_mode=download_mode,
)
elif isinstance(dataset_name, list):
if target is None:
target = 'target'
@@ -192,6 +205,62 @@ class MsDataset:
f' {type(dataset_name)}')
return MsDataset.from_hf_dataset(dataset, target=target)

@staticmethod
def _load_from_ms(
dataset_name: str,
dataset_files: dict,
download_dir: str,
namespace: Optional[str] = None,
version: Optional[str] = DEFAULT_DATASET_REVISION,
subset_name: Optional[str] = None,
split: Optional[str] = None,
download_mode: Optional[DownloadMode] = None,
) -> Union[Dataset, DatasetDict]:
for json_path in dataset_files['.json']:
if json_path.endswith(f'{dataset_name}.json'):
with open(json_path, encoding='utf-8') as dataset_json_file:
dataset_json = json.load(dataset_json_file)
break
target_subset_name, target_dataset_structure = get_target_dataset_structure(
dataset_json, subset_name, split)
meta_map, file_map = get_dataset_files(target_dataset_structure,
dataset_name, namespace,
version)

builder = load_dataset_builder(
dataset_name,
subset_name,
namespace,
meta_data_files=meta_map,
zip_data_files=file_map,
cache_dir=MS_DATASETS_CACHE,
version=version,
split=list(target_dataset_structure.keys()))

download_config = DownloadConfig(
cache_dir=download_dir,
force_download=bool(
download_mode == DownloadMode.FORCE_REDOWNLOAD),
force_extract=bool(download_mode == DownloadMode.FORCE_REDOWNLOAD),
use_etag=False,
)

dl_manager = DatasetDownloadManager(
dataset_name=dataset_name,
namespace=namespace,
version=version,
download_config=download_config,
data_dir=download_dir,
)
builder.download_and_prepare(
download_config=download_config,
dl_manager=dl_manager,
download_mode=download_mode.value,
try_from_hf_gcs=False)

ds = builder.as_dataset()
return ds

def to_torch_dataset_with_processors(
self,
preprocessors: Union[Callable, List[Callable]],
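The native path honours DownloadMode in the same way the hf-compatible branch does; a hedged sketch of forcing a fresh download (dataset coordinates reused from the unit test):

    from modelscope.msdatasets import MsDataset
    from modelscope.utils.constant import DownloadMode

    # Re-fetch the meta files and data archives instead of reusing the cache.
    ds = MsDataset.load(
        'afqmc_small',
        namespace='userxiaoming',
        split='train',
        download_mode=DownloadMode.FORCE_REDOWNLOAD)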


modelscope/msdatasets/utils/__init__.py  (+0, -0)


modelscope/msdatasets/utils/dataset_builder.py  (+113, -0)

@@ -0,0 +1,113 @@
import os
from typing import Mapping, Sequence, Union

import datasets
import pandas as pd
import pyarrow as pa
from datasets.info import DatasetInfo
from datasets.packaged_modules import csv
from datasets.utils.filelock import FileLock

from modelscope.utils.logger import get_logger

logger = get_logger()


class MsCsvDatasetBuilder(csv.Csv):

def __init__(
self,
dataset_name: str,
cache_dir: str,
namespace: str,
subset_name: str,
hash: str,
meta_data_files: Mapping[str, Union[str, Sequence[str]]],
zip_data_files: Mapping[str, Union[str, Sequence[str]]] = None,
**config_kwargs,
):
super().__init__(
cache_dir=cache_dir,
name=subset_name,
hash=hash,
namespace=namespace,
data_files=meta_data_files,
**config_kwargs)

self.name = dataset_name
self.info.builder_name = self.name
self._cache_dir = self._build_cache_dir()
lock_path = os.path.join(
self._cache_dir_root,
self._cache_dir.replace(os.sep, '_') + '.lock')
with FileLock(lock_path):
# check if data exist
if os.path.exists(self._cache_dir):
if len(os.listdir(self._cache_dir)) > 0:
logger.info(
f'Overwrite dataset info from restored data version, cache_dir is {self._cache_dir}'
)
self.info = DatasetInfo.from_directory(self._cache_dir)
# dir exists but no data, remove the empty dir as data aren't available anymore
else:
logger.warning(
f'Old caching folder {self._cache_dir} for dataset {self.name} exists '
f'but no data were found. Removing it.')
os.rmdir(self._cache_dir)
self.zip_data_files = zip_data_files

def _build_cache_dir(self):
builder_data_dir = os.path.join(
self._cache_dir_root,
self._relative_data_dir(with_version=False, with_hash=True))

return builder_data_dir

def _split_generators(self, dl_manager):
if not self.config.data_files:
raise ValueError(
'At least one data file must be specified, but got none.')
data_files = dl_manager.download_and_extract(self.config.data_files)
zip_data_files = dl_manager.download_and_extract(self.zip_data_files)
splits = []
for split_name, files in data_files.items():
if isinstance(files, str):
files = [files]
splits.append(
datasets.SplitGenerator(
name=split_name,
gen_kwargs={
'files': dl_manager.iter_files(files),
'base_dir': zip_data_files.get(split_name)
}))
return splits

def _generate_tables(self, files, base_dir):
schema = pa.schema(self.config.features.type
) if self.config.features is not None else None
dtype = {
name: dtype.to_pandas_dtype()
for name, dtype in zip(schema.names, schema.types)
} if schema else None
for file_idx, file in enumerate(files):
csv_file_reader = pd.read_csv(
file,
iterator=True,
dtype=dtype,
**self.config.read_csv_kwargs)
transform_fields = []
for field_name in csv_file_reader._engine.names:
if field_name.endswith(':FILE'):
transform_fields.append(field_name)
try:
for batch_idx, df in enumerate(csv_file_reader):
for field_name in transform_fields:
if base_dir:
df[field_name] = df[field_name].apply(
lambda x: os.path.join(base_dir, x))
pa_table = pa.Table.from_pandas(df, schema=schema)
yield (file_idx, batch_idx), pa_table
except ValueError as e:
logger.error(
f"Failed to read file '{file}' with error {type(e)}: {e}")
raise
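The ':FILE' column convention handled in _generate_tables, shown on a toy data frame (the column name and base_dir below are made up):

    import os

    import pandas as pd

    # Columns whose header ends with ':FILE' hold paths relative to the
    # extracted zip archive of the split; they are rewritten to full paths.
    df = pd.DataFrame({'image:FILE': ['cats/1.jpg', 'dogs/2.jpg'],
                       'label': [0, 1]})
    base_dir = '/cache/extracted/pictures'  # hypothetical extraction dir
    df['image:FILE'] = df['image:FILE'].apply(
        lambda x: os.path.join(base_dir, x))
    # -> '/cache/extracted/pictures/cats/1.jpg', ...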

modelscope/msdatasets/utils/dataset_utils.py  (+113, -0)

@@ -0,0 +1,113 @@
import os
from collections import defaultdict
from typing import Mapping, Optional, Sequence, Union

from datasets.builder import DatasetBuilder

from modelscope.utils.constant import DEFAULT_DATASET_REVISION
from modelscope.utils.logger import get_logger
from .dataset_builder import MsCsvDatasetBuilder

logger = get_logger()


def get_target_dataset_structure(dataset_structure: dict,
subset_name: Optional[str] = None,
split: Optional[str] = None):
"""
Args:
dataset_structure (dict): Dataset Structure, like
{
"default":{
"train":{
"meta":"my_train.csv",
"file":"pictures.zip"
}
},
"subsetA":{
"test":{
"meta":"mytest.csv",
"file":"pictures.zip"
}
}
}
subset_name (str, optional): Defining the subset_name of the dataset.
split (str, optional): Which split of the data to load.
Returns:
target_subset_name (str): Name of the chosen subset.
target_dataset_structure (dict): Structure of the chosen split(s), like
{
"test":{
"meta":"mytest.csv",
"file":"pictures.zip"
}
}
"""
# verify dataset subset
if (subset_name and subset_name not in dataset_structure) or (
not subset_name and len(dataset_structure.keys()) > 1):
raise ValueError(
f'subset_name {subset_name} not found. Available: {dataset_structure.keys()}'
)
target_subset_name = subset_name
if not subset_name:
target_subset_name = next(iter(dataset_structure.keys()))
logger.info(
f'No subset_name specified, defaulting to {target_subset_name}'
)
# verify dataset split
target_dataset_structure = dataset_structure[target_subset_name]
if split and split not in target_dataset_structure:
raise ValueError(
f'split {split} not found. Available: {target_dataset_structure.keys()}'
)
if split:
target_dataset_structure = {split: target_dataset_structure[split]}
return target_subset_name, target_dataset_structure


def get_dataset_files(subset_split_into: dict,
dataset_name: str,
namespace: str,
revision: Optional[str] = DEFAULT_DATASET_REVISION):
"""
Return:
meta_map: Structure of meta files (.csv), the meta file name will be replaced by url, like
{
"test": "https://xxx/mytest.csv"
}
file_map: Structure of data files (.zip), like
{
"test": "pictures.zip"
}
"""
meta_map = defaultdict(dict)
file_map = defaultdict(dict)
from modelscope.hub.api import HubApi
modelscope_api = HubApi()
for split, info in subset_split_into.items():
meta_map[split] = modelscope_api.get_dataset_file_url(
info['meta'], dataset_name, namespace, revision)
if info.get('file'):
file_map[split] = info['file']
return meta_map, file_map


def load_dataset_builder(dataset_name: str, subset_name: str, namespace: str,
meta_data_files: Mapping[str, Union[str,
Sequence[str]]],
zip_data_files: Mapping[str, Union[str,
Sequence[str]]],
cache_dir: str, version: Optional[Union[str]],
split: Sequence[str]) -> DatasetBuilder:
sub_dir = os.path.join(version, '_'.join(split))
builder_instance = MsCsvDatasetBuilder(
dataset_name=dataset_name,
namespace=namespace,
cache_dir=cache_dir,
subset_name=subset_name,
meta_data_files=meta_data_files,
zip_data_files=zip_data_files,
hash=sub_dir)

return builder_instance
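A small sketch of how these helpers slice the dataset JSON (the structure is the one from the docstring above; the dataset and namespace names are placeholders):

    from modelscope.msdatasets.utils.dataset_utils import (
        get_dataset_files, get_target_dataset_structure)

    dataset_structure = {
        'default': {'train': {'meta': 'my_train.csv', 'file': 'pictures.zip'}},
        'subsetA': {'test': {'meta': 'mytest.csv', 'file': 'pictures.zip'}},
    }
    subset_name, target = get_target_dataset_structure(
        dataset_structure, subset_name='subsetA', split='test')
    # subset_name == 'subsetA'
    # target == {'test': {'meta': 'mytest.csv', 'file': 'pictures.zip'}}

    # meta_map maps each split to the download URL of its csv meta file,
    # file_map to the (relative) zip name resolved against OSS later on.
    meta_map, file_map = get_dataset_files(target, 'my_dataset', 'my_namespace')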

modelscope/msdatasets/utils/download_utils.py  (+41, -0)

@@ -0,0 +1,41 @@
from typing import Optional

from datasets.utils.download_manager import DownloadConfig, DownloadManager
from datasets.utils.file_utils import cached_path, is_relative_path

from .oss_utils import OssUtilities


class DatasetDownloadManager(DownloadManager):

def __init__(
self,
dataset_name: str,
namespace: str,
version: str,
data_dir: Optional[str] = None,
download_config: Optional[DownloadConfig] = None,
base_path: Optional[str] = None,
record_checksums=True,
):
super().__init__(dataset_name, data_dir, download_config, base_path,
record_checksums)
self._namespace = namespace
self._version = version
from modelscope.hub.api import HubApi
api = HubApi()
oss_config = api.get_dataset_access_config(self._dataset_name,
self._namespace,
self._version)
self.oss_utilities = OssUtilities(oss_config)

def _download(self, url_or_filename: str,
download_config: DownloadConfig) -> str:
url_or_filename = str(url_or_filename)
if is_relative_path(url_or_filename):
# fetch oss files
return self.oss_utilities.download(url_or_filename,
self.download_config.cache_dir)
else:
return cached_path(
url_or_filename, download_config=download_config)
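How _download decides between OSS and plain HTTP, using the same is_relative_path test that datasets already ships (the values below are only illustrative):

    from datasets.utils.file_utils import is_relative_path

    # Relative entries coming from the dataset JSON (e.g. the zip listed under
    # 'file') are pulled from OSS with the STS credentials; the generated
    # meta-file URLs are absolute and go through datasets' cached_path().
    is_relative_path('pictures.zip')  # True  -> OssUtilities.download(...)
    is_relative_path('https://example.com/repo?FilePath=mytest.csv')  # False -> cached_path(...)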

modelscope/msdatasets/utils/oss_utils.py  (+37, -0)

@@ -0,0 +1,37 @@
from __future__ import print_function
import os
import sys

import oss2
from datasets.utils.file_utils import hash_url_to_filename


class OssUtilities:

def __init__(self, oss_config):
self.key = oss_config['AccessId']
self.secret = oss_config['AccessSecret']
self.token = oss_config['SecurityToken']
self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com"
self.bucket_name = oss_config['Bucket']
auth = oss2.StsAuth(self.key, self.secret, self.token)
self.bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name)
self.oss_dir = oss_config['Dir']
self.oss_backup_dir = oss_config['BackupDir']

def download(self, oss_file_name, cache_dir):
candidate_key = os.path.join(self.oss_dir, oss_file_name)
candidate_key_backup = os.path.join(self.oss_backup_dir, oss_file_name)
file_oss_key = candidate_key if self.bucket.object_exists(
candidate_key) else candidate_key_backup
filename = hash_url_to_filename(file_oss_key, etag=None)
local_path = os.path.join(cache_dir, filename)

def percentage(consumed_bytes, total_bytes):
if total_bytes:
rate = int(100 * (float(consumed_bytes) / float(total_bytes)))
print('\r{0}% '.format(rate), end='', flush=True)

self.bucket.get_object_to_file(
file_oss_key, local_path, progress_callback=percentage)
return local_path

modelscope/utils/constant.py  (+17, -0)

@@ -152,6 +152,23 @@ class DownloadMode(enum.Enum):
FORCE_REDOWNLOAD = 'force_redownload'


class DatasetFormations(enum.Enum):
""" How a dataset is organized and interpreted
"""
# formation that is compatible with official huggingface dataset, which
# organizes whole dataset into one single (zip) file.
hf_compatible = 1
# native modelscope formation that supports, among other things,
# multiple files in a dataset
native = 2


DatasetMetaFormats = {
DatasetFormations.native: ['.json'],
DatasetFormations.hf_compatible: ['.py'],
}


class ModelFile(object):
CONFIGURATION = 'configuration.json'
README = 'README.md'
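A hedged sketch of how the two new constants work together (the integer comes from the 'Type' field the hub returns, as read in fetch_dataset_scripts above):

    from modelscope.utils.constant import DatasetFormations, DatasetMetaFormats

    # The hub reports the dataset type as an integer: 1 -> hf_compatible,
    # 2 -> native.
    formation = DatasetFormations(2)           # DatasetFormations.native
    meta_exts = DatasetMetaFormats[formation]  # ['.json']: only meta files fetched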


requirements/runtime.txt  (+1, -0)

@@ -6,6 +6,7 @@ filelock>=3.3.0
gast>=0.2.2
numpy
opencv-python
oss2
Pillow>=6.2.0
pyyaml
requests


tests/msdatasets/test_ms_dataset.py  (+6, -2)

@@ -1,7 +1,5 @@
import unittest

import datasets as hfdata

from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.preprocessors import SequenceClassificationPreprocessor
@@ -32,6 +30,12 @@ class ImgPreprocessor(Preprocessor):

class MsDatasetTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ms_csv_basic(self):
ms_ds_train = MsDataset.load(
'afqmc_small', namespace='userxiaoming', split='train')
print(next(iter(ms_ds_train)))

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ds_basic(self):
ms_ds_full = MsDataset.load(


tests/trainers/test_text_generation_trainer.py  (+2, -2)

@@ -21,10 +21,10 @@ class TestTextGenerationTrainer(unittest.TestCase):
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

from datasets import Dataset

self.model_id = 'damo/nlp_palm2.0_text-generation_english-base'

# todo: Replace below scripts with MsDataset.load when the formal dataset service is ready
from datasets import Dataset
dataset_dict = {
'src_txt': [
'This is test sentence1-1', 'This is test sentence2-1',


tests/trainers/test_trainer_with_nlp.py  (+1, -0)

@@ -23,6 +23,7 @@ class TestTrainerWithNlp(unittest.TestCase):
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

# todo: Replace below scripts with MsDataset.load when the formal dataset service is ready
from datasets import Dataset
dataset_dict = {
'sentence1': [

