Browse Source

add chn doc for OBSMindDataset

r1.7
liyong 4 years ago
parent
commit
5a1f3480a5
6 changed files with 66 additions and 21 deletions
  1. +45
    -0
      docs/api/api_python/dataset/mindspore.dataset.OBSMindDataset.rst
  2. +14
    -0
      docs/api/api_python/dataset/mindspore.dataset.sync_wait_dataset.rst
  3. +3
    -1
      docs/api/api_python/mindspore.dataset.rst
  4. +0
    -2
      mindspore/python/mindspore/dataset/engine/datasets_standard_format.py
  5. +3
    -3
      mindspore/python/mindspore/dataset/engine/obs/config_loader.py
  6. +1
    -15
      mindspore/python/mindspore/dataset/engine/obs/util.py

+ 45
- 0
docs/api/api_python/dataset/mindspore.dataset.OBSMindDataset.rst View File

@@ -0,0 +1,45 @@
mindspore.dataset.OBSMindDataset
================================
.. py:class:: mindspore.dataset.OBSMindDataset(dataset_files, server, ak, sk, sync_obs_path, columns_list=None,
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=True)
读取和解析存放在OBS上的MindRecord格式数据集。生成的数据集的列名和列类型取决于MindRecord文件中的保存的列名与类型。
**参数:**
- **dataset_files** (list[str]) - OBS上MindRecord格式数据集文件的路径列表,每个文件的路径前缀为s3://。
- **server** (str) - 连接OBS的服务地址。可包含协议类型、域名、端口号。示例: <https://your-endpoint:9000>。
- **ak** (str) - 访问密钥中的AK。
- **sk** (str) - 访问密钥中的SK。
- **sync_obs_path** (str) - 用于同步操作的OBS路径,用户需要提前创建,目录路径的前缀为s3://。
- **columns_list** (list[str],可选) - 指定从MindRecord文件中读取的数据列。默认值:None,读取所有列。
- **shuffle** (Union[bool, Shuffle], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定,默认值:mindspore.dataset.Shuffle.GLOBAL。
如果 `shuffle` 为False,则不混洗,如果 `shuffle` 为True,等同于将 `shuffle` 设置为mindspore.dataset.Shuffle.GLOBAL。
通过传入枚举变量设置数据混洗的模式:
- **Shuffle.GLOBAL**:混洗文件和文件中的数据。
- **Shuffle.FILES**:仅混洗文件。
- **Shuffle.INFILE**:保持读入文件的序列,仅混洗每个文件中的数据。
- **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数,默认值:None。
- **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号,默认值:None。只有当指定了 `num_shards` 时才能指定此参数。
- **shard_equal_rows** (bool, 可选) - 分布式训练时,为所有分片获取等量的数据行数。默认值:True。
如果 `shard_equal_rows` 为False,则可能会使得每个分片的数据条目不相等,从而导致分布式训练失败。
因此当每个TFRecord文件的数据数量不相等时,建议将此参数设置为True。注意,只有当指定了 `num_shards` 时才能指定此参数。
**异常:**
- **RuntimeError** - `sync_obs_path` 参数指定的目录不存在。
- **ValueError** - `columns_list` 参数无效。
- **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。
- **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。
- **ValueError** - `shard_id` 参数值错误(小于0或者大于等于 `num_shards` )。
.. note::
- 需要用户提前在OBS上创建同步用的目录,然后通过 `sync_obs_path` 指定。
- 如果线下训练,建议为每次训练设置 `BATCH_JOB_ID` 环境变量。
- 分布式训练中,假如使用多个节点(服务器),则必须使用每个节点全部的8张卡。如果只有一个节点(服务器),则没有这样的限制。
.. include:: mindspore.dataset.Dataset.sync_wait_dataset.rst

+ 14
- 0
docs/api/api_python/dataset/mindspore.dataset.sync_wait_dataset.rst View File

@@ -0,0 +1,14 @@
mindspore.dataset.sync_wait_dataset
===================================
.. py:function:: mindspore.dataset.sync_wait_dataset(rank_id, rank_size, current_epoch)
等待所有的卡需要的数据集文件下载完成。
.. note:: 需要配合 `mindspore.dataset.OBSMindDataset` 使用,建议在每次epoch开始前调用。
**参数:**
- **rank_id** (int) - 当前卡的逻辑序号。
- **rank_size** (int) - 卡的数量。
- **current_epoch** (int) - 训练时当前的epoch数。

+ 3
- 1
docs/api/api_python/mindspore.dataset.rst View File

@@ -110,6 +110,7 @@ mindspore.dataset

mindspore.dataset.CSVDataset
mindspore.dataset.MindDataset
mindspore.dataset.OBSMindDataset
mindspore.dataset.TFRecordDataset

用户自定义
@@ -130,7 +131,7 @@ mindspore.dataset

.. mscnautosummary::
:toctree: dataset
mindspore.dataset.GraphData


@@ -169,5 +170,6 @@ mindspore.dataset
mindspore.dataset.deserialize
mindspore.dataset.serialize
mindspore.dataset.show
mindspore.dataset.sync_wait_dataset
mindspore.dataset.utils.imshow_det_bbox
mindspore.dataset.zip

+ 0
- 2
mindspore/python/mindspore/dataset/engine/datasets_standard_format.py View File

@@ -332,7 +332,6 @@ class OBSMindDataset(GeneratorDataset):
The columns of generated dataset depend on the source MindRecord files.

Args:

dataset_files (list[str]): List of files in OBS to be read and file path is in
the format of s3://.
server (str): Endpoint for accessing OBS. For example: <https://your-endpoint:9000>.
@@ -383,7 +382,6 @@ class OBSMindDataset(GeneratorDataset):
>>> sync_obs_dir = "s3://sync-dir"
>>> dataset = ds.MindDataset(dataset_obs_dir, "https://your-endpoint:9000", "AK of OBS", "SK of OBS",
... sync_obs_dir, shuffle=True, num_shards=num_shards, shard_id=shard_id)

"""

@check_obsminddataset


+ 3
- 3
mindspore/python/mindspore/dataset/engine/obs/config_loader.py View File

@@ -25,7 +25,7 @@ class Config():
DATASET_LOCAL_PATH = os.path.join(WORKING_PATH, "dataset")
DISK_THRESHOLD = 0.75
TASK_NUM = 8
PART_SIZE = 10*1024*1024
PART_SIZE = 10 * 1024 * 1024
MAX_RETRY = 3
RETRY_DELTA_TIME = 10

@@ -39,8 +39,8 @@ class _Config:
""" Internal class that get and set global variables. """

def __init__(self):
self.config = dict((k, v) for k, v in Config.__dict__.items(
) if not callable(v) and not k.startswith('__'))
self.config = dict((k, v) for k, v in Config.__dict__.items()
if not callable(v) and not k.startswith('__'))

def __getattr__(self, key):
if key in os.environ:


+ 1
- 15
mindspore/python/mindspore/dataset/engine/obs/util.py View File

@@ -59,7 +59,6 @@ def try_load_from_obs(remote_path, dataset_file, local_path):
remote_path (str): OBS path of dataset files.
dataset_file (str): Name of dataset file.
local_path (str): Local path of dataset files.

"""

if not os.path.exists(os.path.join(local_path, dataset_file)):
@@ -76,7 +75,6 @@ def detect_all_meta_files(meta_files, local_path):
Args:
meta_files (List[str]): Names of meta files.
local_path (str): Local path of dataset files.

"""

all_meta_files = True
@@ -98,10 +96,8 @@ def make_sampler(shuffle, is_full_dataset, start, end):
is_full_dataset (bool): Whether to include full dataset file.
start (int): Start index of sample for non-full dataset file.
end (int): End index of sample for non-full dataset file.

"""


sampler = None
if shuffle in (Shuffle.GLOBAL, Shuffle.INFILE):
if is_full_dataset:
@@ -124,8 +120,8 @@ def make_shard_samples(dataset_file_size_list, size_per_shard, shard_id):
dataset_file_size_list (List[tuple]): List of dataset file name and size.
size_per_shard (int): Size of each sharding.
shard_id (int): ID of sharding.

"""

pre_cnt = 0
shard_files = []
finish = False
@@ -169,7 +165,6 @@ def make_dataset_tuple(dataset_files, local_path):
Args:
dataset_files (List[str]): Full paths of dataset files.
local_path (str): Local directory path of dataset files.

"""

dataset_file_size_list = []
@@ -199,7 +194,6 @@ def fetch_meta_files(meta_files, local_path):
Args:
meta_files (List[str]): Full paths of meta files.
local_path (str): Local directory path of dataset files.

"""

for df in meta_files:
@@ -217,7 +211,6 @@ def make_shard_files(dataset_files, num_shards, shard_id):
dataset_files (List[str]): Names of dataset files.
num_shards (int): Number of all sharding.
sharding (int): ID of sharding.

"""

idx = 0
@@ -239,7 +232,6 @@ def get_bucket_and_key(obs_path):

Returns:
bucketName and objectKey.

"""

start = obs_path.find('//')
@@ -322,7 +314,6 @@ def _check_file_exists_in_obs(obs_path):

Args:
obs_path (str): OBS path of dataset file.

"""

bucket_name, object_key = get_bucket_and_key(obs_path)
@@ -353,7 +344,6 @@ def _file_download_from_obs(obs_path, local_path):
Args:
obs_path (str): OBS path of dataset file.
local_path (str): Local path of dataset file.

"""

bucket_name, object_key = get_bucket_and_key(obs_path)
@@ -382,7 +372,6 @@ def _download_file(remote_path, object_name, des_path, lock_file='tmp'):
object_name (str): Name of dataset file.
des_path (str): Local directory path which dataset file is stored.
lock_file (str): File name to lock.

"""

local_path = os.path.join(des_path, object_name)
@@ -409,7 +398,6 @@ def init_cache_and_queue(cache, q, path, shard_file, idx, is_full_dataset, lock_
idx (int): Index of dataset file.
is_full_dataset (bool): Whether to include full dataset file.
lock_file (str): File name to lock.

"""

dataset_file = os.path.basename(shard_file)
@@ -433,7 +421,6 @@ def _detect_file_exist(local_path, meta_file, lock_file='tmp'):
local_path (str): Local directory path of meta file.
meta_file (str): Name of meta file.
lock_file (str): File name to lock.

"""
if os.path.exists(os.path.join(local_path, meta_file)):
return True
@@ -449,7 +436,6 @@ def file_upload_to_obs(obs_path, sync_dir, ready_file_name):
obs_path (str): OBS path of dataset file.
sync_fir (str): OBS directory path used for synchronization.
ready_file_name (str): Name of synchronization file.

"""

bucket_name, object_key = get_bucket_and_key(obs_path)


Loading…
Cancel
Save