| @@ -0,0 +1,45 @@ | |||
| mindspore.dataset.OBSMindDataset | |||
| ================================ | |||
| .. py:class:: mindspore.dataset.OBSMindDataset(dataset_files, server, ak, sk, sync_obs_path, columns_list=None, | |||
| shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=True) | |||
| 读取和解析存放在OBS上的MindRecord格式数据集。生成的数据集的列名和列类型取决于MindRecord文件中的保存的列名与类型。 | |||
| **参数:** | |||
| - **dataset_files** (list[str]) - OBS上MindRecord格式数据集文件的路径列表,每个文件的路径前缀为s3://。 | |||
| - **server** (str) - 连接OBS的服务地址。可包含协议类型、域名、端口号。示例: <https://your-endpoint:9000>。 | |||
| - **ak** (str) - 访问密钥中的AK。 | |||
| - **sk** (str) - 访问密钥中的SK。 | |||
| - **sync_obs_path** (str) - 用于同步操作的OBS路径,用户需要提前创建,目录路径的前缀为s3://。 | |||
| - **columns_list** (list[str],可选) - 指定从MindRecord文件中读取的数据列。默认值:None,读取所有列。 | |||
| - **shuffle** (Union[bool, Shuffle], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定,默认值:mindspore.dataset.Shuffle.GLOBAL。 | |||
| 如果 `shuffle` 为False,则不混洗,如果 `shuffle` 为True,等同于将 `shuffle` 设置为mindspore.dataset.Shuffle.GLOBAL。 | |||
| 通过传入枚举变量设置数据混洗的模式: | |||
| - **Shuffle.GLOBAL**:混洗文件和文件中的数据。 | |||
| - **Shuffle.FILES**:仅混洗文件。 | |||
| - **Shuffle.INFILE**:保持读入文件的序列,仅混洗每个文件中的数据。 | |||
| - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数,默认值:None。 | |||
| - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号,默认值:None。只有当指定了 `num_shards` 时才能指定此参数。 | |||
| - **shard_equal_rows** (bool, 可选) - 分布式训练时,为所有分片获取等量的数据行数。默认值:True。 | |||
| 如果 `shard_equal_rows` 为False,则可能会使得每个分片的数据条目不相等,从而导致分布式训练失败。 | |||
| 因此当每个MindRecord文件的数据数量不相等时,建议将此参数设置为True。注意,只有当指定了 `num_shards` 时才能指定此参数。 | |||
| **异常:** | |||
| - **RuntimeError** - `sync_obs_path` 参数指定的目录不存在。 | |||
| - **ValueError** - `columns_list` 参数无效。 | |||
| - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 | |||
| - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 | |||
| - **ValueError** - `shard_id` 参数值错误(小于0或者大于等于 `num_shards` )。 | |||
| .. note:: | |||
| - 需要用户提前在OBS上创建同步用的目录,然后通过 `sync_obs_path` 指定。 | |||
| - 如果线下训练,建议为每次训练设置 `BATCH_JOB_ID` 环境变量。 | |||
| - 分布式训练中,假如使用多个节点(服务器),则必须使用每个节点全部的8张卡。如果只有一个节点(服务器),则没有这样的限制。 | |||
| .. include:: mindspore.dataset.Dataset.sync_wait_dataset.rst | |||
| @@ -0,0 +1,14 @@ | |||
| mindspore.dataset.sync_wait_dataset | |||
| =================================== | |||
| .. py:function:: mindspore.dataset.sync_wait_dataset(rank_id, rank_size, current_epoch) | |||
| 等待所有的卡需要的数据集文件下载完成。 | |||
| .. note:: 需要配合 `mindspore.dataset.OBSMindDataset` 使用,建议在每次epoch开始前调用。 | |||
| **参数:** | |||
| - **rank_id** (int) - 当前卡的逻辑序号。 | |||
| - **rank_size** (int) - 卡的数量。 | |||
| - **current_epoch** (int) - 训练时当前的epoch数。 | |||
| @@ -110,6 +110,7 @@ mindspore.dataset | |||
| mindspore.dataset.CSVDataset | |||
| mindspore.dataset.MindDataset | |||
| mindspore.dataset.OBSMindDataset | |||
| mindspore.dataset.TFRecordDataset | |||
| 用户自定义 | |||
| @@ -130,7 +131,7 @@ mindspore.dataset | |||
| .. mscnautosummary:: | |||
| :toctree: dataset | |||
| mindspore.dataset.GraphData | |||
| @@ -169,5 +170,6 @@ mindspore.dataset | |||
| mindspore.dataset.deserialize | |||
| mindspore.dataset.serialize | |||
| mindspore.dataset.show | |||
| mindspore.dataset.sync_wait_dataset | |||
| mindspore.dataset.utils.imshow_det_bbox | |||
| mindspore.dataset.zip | |||
| @@ -332,7 +332,6 @@ class OBSMindDataset(GeneratorDataset): | |||
| The columns of the generated dataset depend on the source MindRecord files. | |||
| Args: | |||
| dataset_files (list[str]): List of files in OBS to be read and file path is in | |||
| the format of s3://. | |||
| server (str): Endpoint for accessing OBS. For example: <https://your-endpoint:9000>. | |||
| @@ -383,7 +382,6 @@ class OBSMindDataset(GeneratorDataset): | |||
| >>> sync_obs_dir = "s3://sync-dir" | |||
| >>> dataset = ds.OBSMindDataset(dataset_obs_dir, "https://your-endpoint:9000", "AK of OBS", "SK of OBS", | |||
| ... sync_obs_dir, shuffle=True, num_shards=num_shards, shard_id=shard_id) | |||
| """ | |||
| @check_obsminddataset | |||
| @@ -25,7 +25,7 @@ class Config(): | |||
| DATASET_LOCAL_PATH = os.path.join(WORKING_PATH, "dataset") | |||
| DISK_THRESHOLD = 0.75 | |||
| TASK_NUM = 8 | |||
| PART_SIZE = 10*1024*1024 | |||
| PART_SIZE = 10 * 1024 * 1024 | |||
| MAX_RETRY = 3 | |||
| RETRY_DELTA_TIME = 10 | |||
| @@ -39,8 +39,8 @@ class _Config: | |||
| """ Internal class that get and set global variables. """ | |||
| def __init__(self): | |||
| self.config = dict((k, v) for k, v in Config.__dict__.items( | |||
| ) if not callable(v) and not k.startswith('__')) | |||
| self.config = dict((k, v) for k, v in Config.__dict__.items() | |||
| if not callable(v) and not k.startswith('__')) | |||
| def __getattr__(self, key): | |||
| if key in os.environ: | |||
| @@ -59,7 +59,6 @@ def try_load_from_obs(remote_path, dataset_file, local_path): | |||
| remote_path (str): OBS path of dataset files. | |||
| dataset_file (str): Name of dataset file. | |||
| local_path (str): Local path of dataset files. | |||
| """ | |||
| if not os.path.exists(os.path.join(local_path, dataset_file)): | |||
| @@ -76,7 +75,6 @@ def detect_all_meta_files(meta_files, local_path): | |||
| Args: | |||
| meta_files (List[str]): Names of meta files. | |||
| local_path (str): Local path of dataset files. | |||
| """ | |||
| all_meta_files = True | |||
| @@ -98,10 +96,8 @@ def make_sampler(shuffle, is_full_dataset, start, end): | |||
| is_full_dataset (bool): Whether to include full dataset file. | |||
| start (int): Start index of sample for non-full dataset file. | |||
| end (int): End index of sample for non-full dataset file. | |||
| """ | |||
| sampler = None | |||
| if shuffle in (Shuffle.GLOBAL, Shuffle.INFILE): | |||
| if is_full_dataset: | |||
| @@ -124,8 +120,8 @@ def make_shard_samples(dataset_file_size_list, size_per_shard, shard_id): | |||
| dataset_file_size_list (List[tuple]): List of dataset file name and size. | |||
| size_per_shard (int): Size of each sharding. | |||
| shard_id (int): ID of sharding. | |||
| """ | |||
| pre_cnt = 0 | |||
| shard_files = [] | |||
| finish = False | |||
| @@ -169,7 +165,6 @@ def make_dataset_tuple(dataset_files, local_path): | |||
| Args: | |||
| dataset_files (List[str]): Full paths of dataset files. | |||
| local_path (str): Local directory path of dataset files. | |||
| """ | |||
| dataset_file_size_list = [] | |||
| @@ -199,7 +194,6 @@ def fetch_meta_files(meta_files, local_path): | |||
| Args: | |||
| meta_files (List[str]): Full paths of meta files. | |||
| local_path (str): Local directory path of dataset files. | |||
| """ | |||
| for df in meta_files: | |||
| @@ -217,7 +211,6 @@ def make_shard_files(dataset_files, num_shards, shard_id): | |||
| dataset_files (List[str]): Names of dataset files. | |||
| num_shards (int): Number of all sharding. | |||
| shard_id (int): ID of sharding. | |||
| """ | |||
| idx = 0 | |||
| @@ -239,7 +232,6 @@ def get_bucket_and_key(obs_path): | |||
| Returns: | |||
| bucketName and objectKey. | |||
| """ | |||
| start = obs_path.find('//') | |||
| @@ -322,7 +314,6 @@ def _check_file_exists_in_obs(obs_path): | |||
| Args: | |||
| obs_path (str): OBS path of dataset file. | |||
| """ | |||
| bucket_name, object_key = get_bucket_and_key(obs_path) | |||
| @@ -353,7 +344,6 @@ def _file_download_from_obs(obs_path, local_path): | |||
| Args: | |||
| obs_path (str): OBS path of dataset file. | |||
| local_path (str): Local path of dataset file. | |||
| """ | |||
| bucket_name, object_key = get_bucket_and_key(obs_path) | |||
| @@ -382,7 +372,6 @@ def _download_file(remote_path, object_name, des_path, lock_file='tmp'): | |||
| object_name (str): Name of dataset file. | |||
| des_path (str): Local directory path which dataset file is stored. | |||
| lock_file (str): File name to lock. | |||
| """ | |||
| local_path = os.path.join(des_path, object_name) | |||
| @@ -409,7 +398,6 @@ def init_cache_and_queue(cache, q, path, shard_file, idx, is_full_dataset, lock_ | |||
| idx (int): Index of dataset file. | |||
| is_full_dataset (bool): Whether to include full dataset file. | |||
| lock_file (str): File name to lock. | |||
| """ | |||
| dataset_file = os.path.basename(shard_file) | |||
| @@ -433,7 +421,6 @@ def _detect_file_exist(local_path, meta_file, lock_file='tmp'): | |||
| local_path (str): Local directory path of meta file. | |||
| meta_file (str): Name of meta file. | |||
| lock_file (str): File name to lock. | |||
| """ | |||
| if os.path.exists(os.path.join(local_path, meta_file)): | |||
| return True | |||
| @@ -449,7 +436,6 @@ def file_upload_to_obs(obs_path, sync_dir, ready_file_name): | |||
| obs_path (str): OBS path of dataset file. | |||
| sync_dir (str): OBS directory path used for synchronization. | |||
| ready_file_name (str): Name of synchronization file. | |||
| """ | |||
| bucket_name, object_key = get_bucket_and_key(obs_path) | |||