From: @tiancixiao
Reviewed-by: @liucunwei, @heleiwang
Signed-off-by: @liucunwei
Tag: tags/v1.2.0-rc1
@@ -102,7 +102,7 @@ def get_seed():
     Get the seed.

     Returns:
-        Int, seed.
+        int, seed.
     """
     return _config.get_seed()
@@ -131,7 +131,7 @@ def get_prefetch_size():
     Get the prefetch size in number of rows.

     Returns:
-        Size, total number of rows to be prefetched.
+        int, total number of rows to be prefetched.
     """
     return _config.get_op_connector_size()
@@ -162,7 +162,7 @@ def get_num_parallel_workers():
     This is the DEFAULT num_parallel_workers value used for each op, it is not related to AutoNumWorker feature.

     Returns:
-        Int, number of parallel workers to be used as a default for each operation
+        int, number of parallel workers to be used as a default for each operation.
     """
     return _config.get_num_parallel_workers()
@@ -193,7 +193,7 @@ def get_numa_enable():
     This is the DEFAULT numa enabled value used for the all process.

     Returns:
-        boolean, the default state of numa enabled
+        bool, the default state of numa enabled.
     """
     return _config.get_numa_enable()
@@ -222,7 +222,7 @@ def get_monitor_sampling_interval():
     Get the default interval of performance monitor sampling.

     Returns:
-        Int, interval (in milliseconds) for performance monitor sampling.
+        int, interval (in milliseconds) for performance monitor sampling.
     """
     return _config.get_monitor_sampling_interval()
@@ -280,7 +280,8 @@ def get_auto_num_workers():
     Get the setting (turned on or off) automatic number of workers.

     Returns:
-        Bool, whether auto num worker feature is turned on
+        bool, whether auto num worker feature is turned on.
+
     Examples:
         >>> num_workers = ds.config.get_auto_num_workers()
     """
@@ -313,7 +314,7 @@ def get_callback_timeout():
     In case of a deadlock, the wait function will exit after the timeout period.

     Returns:
-        Int, the duration in seconds
+        int, the duration in seconds.
     """
     return _config.get_callback_timeout()
@@ -323,7 +324,7 @@ def __str__():
     String representation of the configurations.

     Returns:
-        Str, configurations.
+        str, configurations.
     """
     return str(_config)
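For reference, a minimal sketch of the renamed getters above in use; the seed value is purely illustrative:

```python
import mindspore.dataset as ds

ds.config.set_seed(1234)                          # illustrative seed
print(ds.config.get_seed())                       # int, e.g. 1234
print(ds.config.get_prefetch_size())              # int, rows to be prefetched
print(ds.config.get_num_parallel_workers())       # int, default workers per op
print(ds.config.get_numa_enable())                # bool, default numa state
print(ds.config.get_monitor_sampling_interval())  # int, milliseconds
```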
@@ -80,7 +80,7 @@ def zip(datasets):
     The number of datasets must be more than 1.

     Returns:
-        Dataset, ZipDataset.
+        ZipDataset, dataset zipped.

     Raises:
         ValueError: If the number of datasets is 1.
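A small usage sketch for the return type named above; the in-memory data and column names are hypothetical:

```python
import mindspore.dataset as ds

# Hypothetical in-memory datasets with distinct column names.
ds1 = ds.NumpySlicesDataset({"col1": [1, 2, 3]}, shuffle=False)
ds2 = ds.NumpySlicesDataset({"col2": [4, 5, 6]}, shuffle=False)

zipped = ds.zip((ds1, ds2))  # returns a ZipDataset, per the docstring above
```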
@@ -149,8 +149,8 @@ class Dataset:
         Internal method to create an IR tree.

         Returns:
-            ir_tree, The onject of the IR tree.
-            dataset, the root dataset of the IR tree.
+            DatasetNode, the root node of the IR tree.
+            Dataset, the root dataset of the IR tree.
         """
         parent = self.parent
         self.parent = []
@@ -165,7 +165,7 @@ class Dataset:
         Internal method to parse the API tree into an IR tree.

         Returns:
-            DatasetNode, The root of the IR tree.
+            DatasetNode, the root node of the IR tree.
         """
         if len(self.parent) > 1:
             raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
@@ -197,7 +197,7 @@ class Dataset:
         Args:

         Returns:
-            Python dictionary.
+            dict, attributes related to the current class.
         """
         args = dict()
         args["num_parallel_workers"] = self.num_parallel_workers
@@ -211,7 +211,7 @@ class Dataset:
             filename (str): filename of json file to be saved as

         Returns:
-            Str, JSON string of the pipeline.
+            str, JSON string of the pipeline.
         """
         return json.loads(self.parse_tree().to_json(filename))
@@ -258,6 +258,9 @@ class Dataset:
             drop_remainder (bool, optional): If True, will drop the last batch for each
                 bucket if it is not a full batch (default=False).

+        Returns:
+            BucketBatchByLengthDataset, dataset bucketed and batched by length.
+
         Examples:
             >>> import mindspore.dataset as ds
             >>>
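To make the new `Returns` entry concrete, a sketch of bucketing variable-length rows; the generator, boundaries, and batch sizes are made up for illustration:

```python
import numpy as np
import mindspore.dataset as ds

def gen():
    for n in (1, 2, 3, 4):                 # rows of varying length
        yield (np.zeros(n, dtype=np.int32),)

data = ds.GeneratorDataset(gen, column_names=["col"])
data = data.bucket_batch_by_length(
    column_names=["col"],
    bucket_boundaries=[3],                 # two buckets: len < 3 and len >= 3
    bucket_batch_sizes=[2, 2],             # one batch size per bucket
    pad_info={},                           # pad rows within a batch to equal length
    drop_remainder=False)                  # returns a BucketBatchByLengthDataset
```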
@@ -371,6 +374,9 @@ class Dataset:
             num_batch (int): the number of batches without blocking at the start of each epoch.
             callback (function): The callback funciton that will be invoked when sync_update is called.

+        Returns:
+            SyncWaitDataset, dataset added a blocking condition.
+
         Raises:
             RuntimeError: If condition name already exists.
@@ -434,7 +440,7 @@ class Dataset:
             return a 'Dataset'.

         Returns:
-            Dataset, applied by the function.
+            Dataset, dataset applied by the function.

         Examples:
             >>> import mindspore.dataset as ds
@@ -650,7 +656,7 @@ class Dataset:
                 in parallel (default=None).

         Returns:
-            FilterDataset, dataset filter.
+            FilterDataset, dataset filtered.

         Examples:
             >>> import mindspore.dataset as ds
@@ -748,6 +754,9 @@ class Dataset:
         """
         Internal method called by split to calculate absolute split sizes and to
         do some error checking after calculating absolute split sizes.
+
+        Returns:
+            int, absolute split sizes of the dataset.
         """
         # Call get_dataset_size here and check input here because
         # don't want to call this once in check_split and another time in
@@ -1015,7 +1024,7 @@ class Dataset:
             is specified and special_first is set to default, special_tokens will be prepended

         Returns:
-            Vocab node
+            Vocab, vocab built from the dataset.

         Example:
             >>> import mindspore.dataset as ds
@@ -1074,7 +1083,7 @@ class Dataset:
             params(dict): contains more optional parameters of sentencepiece library

         Returns:
-            SentencePieceVocab node
+            SentencePieceVocab, vocab built from the dataset.

         Example:
             >>> import mindspore.dataset as ds
@@ -1115,7 +1124,7 @@ class Dataset:
             return a preprogressing 'Dataset'.

         Returns:
-            Dataset, applied by the function.
+            Dataset, dataset applied by the function.

         Examples:
             >>> import mindspore.dataset as ds
@@ -1159,7 +1168,7 @@ class Dataset:
             If device is Ascend, features of data will be transferred one by one. The limitation
             of data transmission per time is 256M.

-        Return:
+        Returns:
             TransferDataset, dataset for transferring.
         """
         return self.to_device(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue)
@@ -1287,7 +1296,7 @@ class Dataset:
                 use this param to select the conversion method, only take False for better performance (default=True).

         Returns:
-            Iterator, list of ndarrays.
+            TupleIterator, tuple iterator over the dataset.

         Examples:
             >>> import mindspore.dataset as ds
@@ -1322,7 +1331,7 @@ class Dataset:
                 if output_numpy=False, iterator will output MSTensor (default=False).

         Returns:
-            Iterator, dictionary of column name-ndarray pair.
+            DictIterator, dictionary iterator over the dataset.

         Examples:
             >>> import mindspore.dataset as ds
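A sketch showing the two iterator types named above; the in-memory data is hypothetical:

```python
import mindspore.dataset as ds

data = ds.NumpySlicesDataset({"x": [1, 2, 3]}, shuffle=False)

# DictIterator: each row is a {column_name: value} dictionary.
for row in data.create_dict_iterator(output_numpy=True):
    print(row["x"])

# TupleIterator: each row is a list of values, one per column.
for row in data.create_tuple_iterator(output_numpy=True):
    print(row[0])
```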
@@ -1352,6 +1361,9 @@ class Dataset:
         """
         Get Input Index Information
+
+        Returns:
+            tuple, tuple of the input index information.

         Examples:
             >>> import mindspore.dataset as ds
             >>>
@@ -1409,6 +1421,9 @@ class Dataset:
     def get_col_names(self):
         """
         Get names of the columns in the dataset
+
+        Returns:
+            list, list of column names in the dataset.
         """
         if self._col_names is None:
             runtime_getter = self._init_tree_getters()
@@ -1419,8 +1434,8 @@ class Dataset:
         """
         Get the shapes of output data.

-        Return:
-            List, list of shapes of each column.
+        Returns:
+            list, list of shapes of each column.
         """
         if self.saved_output_shapes is None:
             runtime_getter = self._init_tree_getters()
@@ -1432,8 +1447,8 @@ class Dataset:
         """
         Get the types of output data.

-        Return:
-            List of data types.
+        Returns:
+            list, list of data types.
         """
         if self.saved_output_types is None:
             runtime_getter = self._init_tree_getters()
@@ -1445,8 +1460,8 @@ class Dataset:
         """
         Get the number of batches in an epoch.

-        Return:
-            Number, number of batches.
+        Returns:
+            int, number of batches.
         """
         if self.dataset_size is None:
             runtime_getter = self._init_size_getter()
@@ -1457,8 +1472,8 @@ class Dataset:
         """
         Get the number of classes in a dataset.

-        Return:
-            Number, number of classes.
+        Returns:
+            int, number of classes.
         """
         if self._num_classes is None:
             runtime_getter = self._init_tree_getters()
@@ -1511,8 +1526,8 @@ class Dataset:
         """
         Get the size of a batch.

-        Return:
-            Number, the number of data in a batch.
+        Returns:
+            int, the number of data in a batch.
         """
         if self._batch_size is None:
             runtime_getter = self._init_tree_getters()
@@ -1525,8 +1540,8 @@ class Dataset:
         """
         Get the replication times in RepeatDataset else 1.

-        Return:
-            Number, the count of repeat.
+        Returns:
+            int, the count of repeat.
         """
         if self._repeat_count is None:
             runtime_getter = self._init_tree_getters()
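The getters normalized in this block can be exercised together; a sketch on a hypothetical pipeline:

```python
import mindspore.dataset as ds

data = ds.NumpySlicesDataset({"x": list(range(10))}, shuffle=False)
data = data.batch(2).repeat(3)

print(data.output_shapes())     # list, one shape per column
print(data.output_types())      # list, one data type per column
print(data.get_dataset_size())  # int, number of batches in an epoch
print(data.get_batch_size())    # int, 2 here
print(data.get_repeat_count())  # int, 3 here
```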
@@ -1540,8 +1555,8 @@ class Dataset:
         Get the class index.

         Returns:
-            Dict, A str-to-int mapping from label name to index.
-            Dict, A str-to-list<int> mapping from label name to index for Coco ONLY. The second number
+            dict, a str-to-int mapping from label name to index.
+            dict, a str-to-list<int> mapping from label name to index for Coco ONLY. The second number
             in the list is used to indicate the super category
         """
         if self.children:
@@ -1588,7 +1603,7 @@ class SourceDataset(Dataset):
             patterns (Union[str, list[str]]): String or list of patterns to be searched.

         Returns:
-            List, files.
+            list, list of files.
         """
         if not isinstance(patterns, list):
@@ -1646,9 +1661,6 @@ class MappableDataset(SourceDataset):
         Args:
             new_sampler (Sampler): The sampler to use for the current dataset.

-        Returns:
-            Dataset, that uses new_sampler.
-
         Examples:
             >>> import mindspore.dataset as ds
             >>>
@@ -1909,8 +1921,9 @@ class BatchDataset(Dataset):
         Args:
             dataset (Dataset): Dataset to be checked.

-        Return:
-            True or False.
+        Returns:
+            bool, whether repeat is used before batch.
         """
         if isinstance(dataset, RepeatDataset):
             return True
@@ -1995,18 +2008,12 @@ class BatchInfo(cde.CBatchInfo):
     def get_batch_num(self):
         """
         Return the batch number of the current batch.
-
-        Return:
-            Number, number of the current batch.
         """
         return

     def get_epoch_num(self):
         """
         Return the epoch number of the current batch.
-
-        Return:
-            Number, number of the current epoch.
         """
         return
@@ -2055,8 +2062,8 @@ class BlockReleasePair:
         """
         Function for handing blocking condition.

-        Return:
-            True
+        Returns:
+            bool, True.
         """
         with self.cv:
             # if disable is true, the always evaluate to true
@@ -2145,8 +2152,9 @@ class SyncWaitDataset(Dataset):
         Args:
             dataset (Dataset): Dataset to be checked.

-        Return:
-            True or False.
+        Returns:
+            bool, whether sync_wait is used before batch.
         """
         if isinstance(dataset, BatchDataset):
             return True
@@ -2932,6 +2940,9 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, n
         num_shards (int): Number of shard for sharding.
        shard_id (int): Shard ID.
        non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False).
+
+    Returns:
+        Sampler, sampler selected based on user input.
     """
     if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]):
         return None
@@ -4180,7 +4191,7 @@ class ManifestDataset(MappableDataset):
         Get the class index.

         Returns:
-            Dict, A str-to-int mapping from label name to index.
+            dict, a str-to-int mapping from label name to index.
         """
         if self.class_indexing is None:
             if self._class_indexing is None:
@@ -4579,7 +4590,7 @@ class Schema:
         Args:
             schema_file(str): Path of schema file (default=None).

-        Return:
+        Returns:
             Schema object, schema info about dataset.

         Raises:
@@ -4654,7 +4665,7 @@ class Schema:
         Get a JSON string of the schema.

         Returns:
-            Str, JSON string of the schema.
+            str, JSON string of the schema.
         """
         return self.cpp_schema.to_json()
@@ -4840,7 +4851,7 @@ class VOCDataset(MappableDataset):
         Get the class index.

         Returns:
-            Dict, A str-to-int mapping from label name to index.
+            dict, a str-to-int mapping from label name to index.
         """
         if self.task != "Detection":
             raise NotImplementedError("Only 'Detection' support get_class_indexing.")
@@ -5032,7 +5043,7 @@ class CocoDataset(MappableDataset):
         Get the class index.

         Returns:
-            Dict, A str-to-list<int> mapping from label name to index
+            dict, a str-to-list<int> mapping from label name to index
         """
         if self.task not in {"Detection", "Panoptic"}:
             raise NotImplementedError("Only 'Detection' and 'Panoptic' support get_class_indexing.")
@@ -100,7 +100,7 @@ class GraphData:
             node_type (int): Specify the type of node.

         Returns:
-            numpy.ndarray: Array of nodes.
+            numpy.ndarray, array of nodes.

         Examples:
             >>> import mindspore.dataset as ds
@@ -124,7 +124,7 @@ class GraphData:
             edge_type (int): Specify the type of edge.

         Returns:
-            numpy.ndarray: array of edges.
+            numpy.ndarray, array of edges.

         Examples:
             >>> import mindspore.dataset as ds
@@ -148,7 +148,7 @@ class GraphData:
             edge_list (Union[list, numpy.ndarray]): The given list of edges.

         Returns:
-            numpy.ndarray: Array of nodes.
+            numpy.ndarray, array of nodes.

         Raises:
             TypeError: If `edge_list` is not list or ndarray.
@@ -167,7 +167,7 @@ class GraphData:
             neighbor_type (int): Specify the type of neighbor.

         Returns:
-            numpy.ndarray: Array of nodes.
+            numpy.ndarray, array of neighbors.

         Examples:
             >>> import mindspore.dataset as ds
@@ -201,7 +201,7 @@ class GraphData:
             neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop.

         Returns:
-            numpy.ndarray: Array of nodes.
+            numpy.ndarray, array of neighbors.

         Examples:
             >>> import mindspore.dataset as ds
@@ -231,7 +231,7 @@ class GraphData:
             neg_neighbor_type (int): Specify the type of negative neighbor.

         Returns:
-            numpy.ndarray: Array of nodes.
+            numpy.ndarray, array of neighbors.

         Examples:
             >>> import mindspore.dataset as ds
@@ -260,7 +260,7 @@ class GraphData:
             feature_types (Union[list, numpy.ndarray]): The given list of feature types.

         Returns:
-            numpy.ndarray: array of features.
+            numpy.ndarray, array of features.

         Examples:
             >>> import mindspore.dataset as ds
@@ -292,7 +292,7 @@ class GraphData:
             feature_types (Union[list, numpy.ndarray]): The given list of feature types.

         Returns:
-            numpy.ndarray: array of features.
+            numpy.ndarray, array of features.

         Examples:
             >>> import mindspore.dataset as ds
@@ -320,7 +320,7 @@ class GraphData:
         the feature information of nodes, the number of edges, the type of edges, and the feature information of edges.

         Returns:
-            dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
+            dict, meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
             node_feature_type and edge_feature_type.
         """
         if self._working_mode == 'server':
@@ -347,7 +347,7 @@ class GraphData:
             A default value of -1 indicates that no node is given.

         Returns:
-            numpy.ndarray: Array of nodes.
+            numpy.ndarray, array of nodes.

         Examples:
             >>> import mindspore.dataset as ds
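A sketch of the GraphData calls whose return descriptions were normalized above; the file path and type ids are placeholders that depend on how the graph was built:

```python
import mindspore.dataset as ds

g = ds.GraphData("graph.mindrecord")             # placeholder file path
nodes = g.get_all_nodes(node_type=1)             # numpy.ndarray of node ids
neighbors = g.get_all_neighbors(nodes.tolist(), neighbor_type=2)
features = g.get_node_feature(nodes.tolist(), feature_types=[1])
info = g.graph_info()                            # dict: node_type, edge_type, node_num, ...
```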
@@ -128,6 +128,7 @@ class BuiltinSampler:
     User should not extend this class.
     """
+
     def __init__(self, num_samples=None):
         self.child_sampler = None
         self.num_samples = num_samples
@@ -201,7 +202,7 @@ class BuiltinSampler:
            - None

         Returns:
-            int, The number of samples, or None
+            int, the number of samples, or None
         """
         if self.child_sampler is not None:
             child_samples = self.child_sampler.get_num_samples()
@@ -1063,8 +1063,9 @@ class LinearTransformation:
     the dot product with the transformation matrix, and reshapes it back to its original shape.

     Args:
-        transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), D = C x H x W.
-        mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where D = C x H x W.
+        transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), where
+            :math:`D = C \times H \times W`.
+        mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where :math:`D = C \times H \times W`.

     Examples:
         >>> from mindspore.dataset.transforms.py_transforms import Compose
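The `:math:` change makes the dimension explicit; a sketch of constructing conforming inputs for a 3x32x32 image (identity matrix and zero mean, i.e., a no-op transform):

```python
import numpy as np

C, H, W = 3, 32, 32
D = C * H * W                      # D = C x H x W = 3072

transformation_matrix = np.eye(D)  # square (D, D) matrix; identity is a no-op
mean_vector = np.zeros(D)          # shape (D,)

# LinearTransformation(transformation_matrix, mean_vector) flattens the image
# to (D,), subtracts mean_vector, applies the matrix, and reshapes back.
```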
@@ -23,7 +23,7 @@ from PIL import Image
 import mindspore as ms
 import mindspore.dataset as ds
 from mindspore import log
-from mindspore.dataset.engine.datasets import Dataset
+from mindspore.dataset import Dataset
 from mindspore.nn import Cell, SequentialCell
 from mindspore.ops.operations import ExpandDims
 from mindspore.train._utils import check_value_type
@@ -30,6 +30,7 @@ from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchema
 __all__ = ['FileWriter']

+
 class FileWriter:
     """
     Class to write user defined raw data into MindRecord File series.
@@ -45,6 +46,7 @@ class FileWriter:
     Raises:
         ParamValueError: If `file_name` or `shard_num` is invalid.
     """
+
     def __init__(self, file_name, shard_num=1):
         check_filename(file_name)
         self._file_name = file_name
@@ -84,7 +86,7 @@ class FileWriter:
             file_name (str): String of MindRecord file name.

         Returns:
-            Instance of FileWriter.
+            FileWriter, file writer for the opened MindRecord file.

         Raises:
             ParamValueError: If file_name is invalid.
@@ -118,7 +120,7 @@ class FileWriter:
             desc (str, optional): String of schema description (default=None).

         Returns:
-            An integer, schema id.
+            int, schema id.

         Raises:
             MRMInvalidSchemaError: If schema is invalid.
@@ -175,17 +177,17 @@ class FileWriter:
                 if field not in v:
                     error_data_dic[i] = "for schema, {} th data is wrong, " \
-                        "there is not '{}' object in the raw data.".format(i, field)
+                                        "there is not '{}' object in the raw data.".format(i, field)
                     continue
                 field_type = type(v[field]).__name__
                 if field_type not in VALUE_TYPE_MAP:
                     error_data_dic[i] = "for schema, {} th data is wrong, " \
-                        "data type for '{}' is not matched.".format(i, field)
+                                        "data type for '{}' is not matched.".format(i, field)
                     continue
                 if schema_content[field]["type"] not in VALUE_TYPE_MAP[field_type]:
                     error_data_dic[i] = "for schema, {} th data is wrong, " \
-                        "data type for '{}' is not matched.".format(i, field)
+                                        "data type for '{}' is not matched.".format(i, field)
                     continue
                 if field_type == 'ndarray':
@@ -206,7 +208,6 @@ class FileWriter:
     def open_and_set_header(self):
         """
         Open writer and set header.
-
         """
         if not self._writer.is_open:
             self._writer.open(self._paths)
@@ -222,6 +223,9 @@ class FileWriter:
             raw_data (list[dict]): List of raw data.
             parallel_writer (bool, optional): Load data parallel if it equals to True (default=False).

+        Returns:
+            MSRStatus, SUCCESS or FAILED.
+
         Raises:
             ParamTypeError: If index field is invalid.
             MRMOpenError: If failed to open MindRecord File.
@@ -330,7 +334,7 @@ class FileWriter:
             v (dict): Sub dict in schema

         Returns:
-            bool, True or False.
+            bool, whether the array item is valid.
             str, error message.
         """
         if v['type'] not in VALID_ARRAY_ATTRIBUTES:
@@ -355,7 +359,7 @@ class FileWriter:
             content (dict): Dict of raw schema.

         Returns:
-            bool, True or False.
+            bool, whether the schema is valid.
             str, error message.
         """
         error = ''
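Pulling the FileWriter changes together, a minimal write path; the file name and schema are illustrative:

```python
from mindspore.mindrecord import FileWriter

writer = FileWriter(file_name="example.mindrecord", shard_num=1)  # placeholder path
schema = {"id": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(schema, "example schema")          # returns int, schema id
writer.write_raw_data([{"id": 0, "data": b"\x00"}])  # returns MSRStatus
writer.commit()                                      # MSRStatus, SUCCESS or FAILED
```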
@@ -23,6 +23,7 @@ from .common.exceptions import ParamValueError, ParamTypeError, MRMDefineCategor
 __all__ = ['MindPage']

+
 class MindPage:
     """
     Class to read MindRecord File series in pagination.
@@ -36,6 +37,7 @@ class MindPage:
         ParamValueError: If `file_name`, `num_consumer` or columns is invalid.
         MRMInitSegmentError: If failed to initialize ShardSegment.
     """
+
     def __init__(self, file_name, num_consumer=4):
         if isinstance(file_name, list):
             for f in file_name:
@@ -69,7 +71,12 @@ class MindPage:
         return self._candidate_fields

     def get_category_fields(self):
-        """Return candidate category fields."""
+        """
+        Return candidate category fields.
+
+        Returns:
+            list[str], by which data could be grouped.
+        """
         logger.warning("WARN_DEPRECATED: The usage of get_category_fields is deprecated."
                        " Please use candidate_fields")
         return self.candidate_fields
@@ -97,12 +104,22 @@ class MindPage:
     @property
     def category_field(self):
-        """Getter function for category fields."""
+        """
+        Getter function for category fields.
+
+        Returns:
+            list[str], by which data could be grouped.
+        """
         return self._category_field

     @category_field.setter
     def category_field(self, category_field):
-        """Setter function for category field"""
+        """
+        Setter function for category field.
+
+        Returns:
+            MSRStatus, SUCCESS or FAILED.
+        """
         if not category_field or not isinstance(category_field, str):
             raise ParamTypeError('category_fields', 'str')
         if category_field not in self._candidate_fields:
@@ -132,7 +149,7 @@ class MindPage:
             num_row (int): Number of rows in a page.

         Returns:
-            List, list[dict].
+            list[dict], data queried by category id.

         Raises:
             ParamValueError: If any parameter is invalid.
@@ -158,7 +175,7 @@ class MindPage:
             num_row (int): Number of row in a page.

         Returns:
-            str, read at page.
+            list[dict], data queried by category name.
         """
         if not isinstance(category_name, str):
             raise ParamValueError("Category name should be str.")
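A sketch of paginated reading with the documented return types; the file name and category values are placeholders:

```python
from mindspore.mindrecord import MindPage

reader = MindPage("example.mindrecord")  # placeholder file name
print(reader.candidate_fields)           # list[str], fields data can be grouped by

reader.category_field = reader.candidate_fields[0]
rows = reader.read_at_page_by_id(category_id=0, page=0, num_row=16)
# rows is list[dict], one dict per record in the page
```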
@@ -23,6 +23,7 @@ from .common.exceptions import MRMOpenError, MRMOpenForAppendError, MRMInvalidHe
 __all__ = ['ShardWriter']

+
 class ShardWriter:
     """
     Wrapper class which is represent shardWrite class in c++ module.
@@ -192,9 +193,11 @@ class ShardWriter:
         if len(blob_data) == 1:
             values = [v for v in blob_data.values()]
             return bytes(values[0])
+
         # convert int to bytes
         def int_to_bytes(x: int) -> bytes:
             return x.to_bytes(8, 'big')
+
         merged = bytes()
         for field, v in blob_data.items():
             # convert ndarray to bytes
@@ -209,7 +212,7 @@ class ShardWriter:
         Flush data to disk.

         Returns:
-            Class MSRStatus, SUCCESS or FAILED.
+            MSRStatus, SUCCESS or FAILED.

         Raises:
             MRMCommitError: If failed to flush data to disk.
@@ -33,6 +33,7 @@ except ModuleNotFoundError:
 __all__ = ['Cifar100ToMR']

+
 class Cifar100ToMR:
     """
     A class to transform from cifar100 to MindRecord.
@@ -44,6 +45,7 @@ class Cifar100ToMR:
     Raises:
         ValueError: If source or destination is invalid.
     """
+
     def __init__(self, source, destination):
         check_filename(source)
         self.source = source
@@ -74,7 +76,7 @@ class Cifar100ToMR:
             fields (list[str]): A list of index field, e.g.["fine_label", "coarse_label"].

         Returns:
-            SUCCESS or FAILED, whether cifar100 is successfully transformed to MindRecord.
+            MSRStatus, whether cifar100 is successfully transformed to MindRecord.
         """
         if fields and not isinstance(fields, list):
             raise ValueError("The parameter fields should be None or list")
@@ -114,6 +116,7 @@ class Cifar100ToMR:
             raise t.exception
         return t.res

+
 def _construct_raw_data(images, fine_labels, coarse_labels):
     """
     Construct raw data from cifar100 data.
@@ -124,7 +127,7 @@ def _construct_raw_data(images, fine_labels, coarse_labels):
         coarse_labels (list): coarse label list from cifar100.

     Returns:
-        SUCCESS/FAILED, whether successfully written into MindRecord.
+        list[dict], data dictionary constructed from cifar100.
     """
     if not cv2:
         raise ModuleNotFoundError("opencv-python module not found, please use pip install it.")
@@ -141,6 +144,7 @@ def _construct_raw_data(images, fine_labels, coarse_labels):
         raw_data.append(row_data)
     return raw_data

+
 def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
     """
     Generate MindRecord file from raw data.
@@ -153,7 +157,7 @@ def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
         schema_desc (str): String of schema description.

     Returns:
-        SUCCESS/FAILED, whether successfully written into MindRecord.
+        MSRStatus, whether successfully written into MindRecord.
     """
     schema = {"id": {"type": "int64"}, "fine_label": {"type": "int64"},
               "coarse_label": {"type": "int64"}, "data": {"type": "bytes"}}
@@ -25,6 +25,7 @@ from .cifar10 import Cifar10
 from ..common.exceptions import PathNotExistsError
 from ..filewriter import FileWriter
 from ..shardutils import check_filename, ExceptionThread, SUCCESS, FAILED

+
 try:
     cv2 = import_module("cv2")
 except ModuleNotFoundError:
@@ -32,6 +33,7 @@ except ModuleNotFoundError:
 __all__ = ['Cifar10ToMR']

+
 class Cifar10ToMR:
     """
     A class to transform from cifar10 to MindRecord.
@@ -43,6 +45,7 @@ class Cifar10ToMR:
     Raises:
         ValueError: If source or destination is invalid.
     """
+
     def __init__(self, source, destination):
         check_filename(source)
         self.source = source
@@ -73,7 +76,7 @@ class Cifar10ToMR:
             fields (list[str], optional): A list of index fields, e.g.["label"] (default=None).

         Returns:
-            SUCCESS or FAILED, whether cifar10 is successfully transformed to MindRecord.
+            MSRStatus, whether cifar10 is successfully transformed to MindRecord.
         """
         if fields and not isinstance(fields, list):
             raise ValueError("The parameter fields should be None or list")
@@ -109,6 +112,7 @@ class Cifar10ToMR:
             raise t.exception
         return t.res

+
 def _construct_raw_data(images, labels):
     """
     Construct raw data from cifar10 data.
@@ -118,7 +122,7 @@ def _construct_raw_data(images, labels):
         labels (list): label list from cifar10.

     Returns:
-        SUCCESS/FAILED, whether successfully written into MindRecord.
+        list[dict], data dictionary constructed from cifar10.
     """
     if not cv2:
         raise ModuleNotFoundError("opencv-python module not found, please use pip install it.")
@@ -133,6 +137,7 @@ def _construct_raw_data(images, labels):
         raw_data.append(row_data)
     return raw_data

+
 def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
     """
     Generate MindRecord file from raw data.
@@ -145,7 +150,7 @@ def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
         schema_desc (str): String of schema description.

     Returns:
-        SUCCESS/FAILED, whether successfully written into MindRecord.
+        MSRStatus, whether successfully written into MindRecord.
     """
     schema = {"id": {"type": "int64"}, "label": {"type": "int64"},
               "data": {"type": "bytes"}}
@@ -29,6 +29,7 @@ except ModuleNotFoundError:
 __all__ = ['CsvToMR']

+
 class CsvToMR:
     """
     A class to transform from csv to MindRecord.
@@ -121,7 +122,7 @@ class CsvToMR:
         Executes transformation from csv to MindRecord.

         Returns:
-            SUCCESS or FAILED, whether csv is successfully transformed to MindRecord.
+            MSRStatus, whether csv is successfully transformed to MindRecord.
         """
         if not os.path.exists(self.source):
             raise IOError("Csv file {} do not exist.".format(self.source))
@@ -47,6 +47,7 @@ class ImageNetToMR:
     Raises:
         ValueError: If `map_file`, `image_dir` or `destination` is invalid.
     """
+
     def __init__(self, map_file, image_dir, destination, partition_number=1):
         check_filename(map_file)
         self.map_file = map_file
@@ -122,7 +123,7 @@ class ImageNetToMR:
         Executes transformation from imagenet to MindRecord.

         Returns:
-            SUCCESS or FAILED, whether imagenet is successfully transformed to MindRecord.
+            MSRStatus, whether imagenet is successfully transformed to MindRecord.
         """
         t0_total = time.time()
@@ -133,10 +134,10 @@ class ImageNetToMR:
         logger.info("transformed MindRecord schema is: {}".format(imagenet_schema_json))

         # set the header size
-        self.writer.set_header_size(1<<24)
+        self.writer.set_header_size(1 << 24)

         # set the page size
-        self.writer.set_page_size(1<<26)
+        self.writer.set_page_size(1 << 26)

         # create the schema
         self.writer.add_schema(imagenet_schema_json, "imagenet_schema")
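The spacing fix leaves the values unchanged; the shifted constants are byte sizes:

```python
# 1 << 24 bytes = 16 MB header size; 1 << 26 bytes = 64 MB page size.
assert 1 << 24 == 16 * 1024 * 1024
assert 1 << 26 == 64 * 1024 * 1024
```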
@@ -32,6 +32,7 @@ except ModuleNotFoundError:
 __all__ = ['MnistToMR']

+
 class MnistToMR:
     """
     A class to transform from Mnist to MindRecord.
@@ -125,7 +126,7 @@ class MnistToMR:
         Executes transformation from Mnist train part to MindRecord.

         Returns:
-            SUCCESS/FAILED, whether successfully written into MindRecord.
+            MSRStatus, whether successfully written into MindRecord.
         """
         t0_total = time.time()
@@ -173,7 +174,7 @@ class MnistToMR:
         Executes transformation from Mnist test part to MindRecord.

         Returns:
-            SUCCESS or FAILED, whether Mnist is successfully transformed to MindRecord.
+            MSRStatus, whether Mnist is successfully transformed to MindRecord.
         """
         t0_total = time.time()
@@ -222,7 +223,7 @@ class MnistToMR:
         Executes transformation from Mnist to MindRecord.

         Returns:
-            SUCCESS/FAILED, whether successfully written into MindRecord.
+            MSRStatus, whether successfully written into MindRecord.
         """
         if not cv2:
             raise ModuleNotFoundError("opencv-python module not found, please use pip install it.")
@@ -23,7 +23,6 @@ from mindspore import log as logger
 from ..filewriter import FileWriter
 from ..shardutils import check_filename, ExceptionThread

-
 __all__ = ['TFRecordToMR']

 SupportedTensorFlowVersion = '1.13.0-rc1'
@@ -86,9 +85,10 @@ class TFRecordToMR:
         ValueError: If parameter is invalid.
         Exception: when tensorflow module is not found or version is not correct.
     """
+
     def __init__(self, source, destination, feature_dict, bytes_fields=None):
         try:
-            self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
+            self.tf = import_module("tensorflow")  # just used to convert tfrecord to mindrecord
         except ModuleNotFoundError:
             self.tf = None
         if not self.tf:
@@ -265,7 +265,7 @@ class TFRecordToMR:
         Execute transformation from TFRecord to MindRecord.

         Returns:
-            SUCCESS or FAILED, whether TFRecord is successfuly transformed to MindRecord.
+            MSRStatus, whether TFRecord is successfully transformed to MindRecord.
         """
         writer = FileWriter(self.destination)
         logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}"
@@ -22,13 +22,12 @@ import pickle
 import numpy as np
 import pandas as pd
-from mindspore.dataset.engine import GeneratorDataset
+from mindspore.dataset import GeneratorDataset

 import src.constants as rconst
 import src.movielens as movielens
 import src.stat_utils as stat_utils

-
 DATASET_TO_NUM_USERS_AND_ITEMS = {
     "ml-1m": (6040, 3706),
     "ml-20m": (138493, 26744)
@@ -205,6 +204,7 @@ class NCFDataset:
     """
     A dataset for NCF network.
     """
+
     def __init__(self,
                  pos_users,
                  pos_items,
@@ -407,6 +407,7 @@ class RandomSampler:
     """
     A random sampler for dataset.
     """
+
     def __init__(self, pos_count, num_train_negatives, batch_size):
         self.pos_count = pos_count
         self._num_samples = (1 + num_train_negatives) * self.pos_count
@@ -433,6 +434,7 @@ class DistributedSamplerOfTrain:
     """
     A distributed sampler for dataset.
     """
+
     def __init__(self, pos_count, num_train_negatives, batch_size, rank_id, rank_size):
         """
         Distributed sampler of training dataset.
@@ -443,15 +445,16 @@ class DistributedSamplerOfTrain:
         self._batch_size = batch_size

         self._batchs_per_rank = int(math.ceil(self._num_samples / self._batch_size / rank_size))
-        self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size)) 
+        self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size))
         self._total_num_samples = self._samples_per_rank * self._rank_size

     def __iter__(self):
         """
         Returns the data after each sampling.
         """
         indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))
         indices = indices.tolist()
-        indices.extend(indices[:self._total_num_samples-len(indices)])
+        indices.extend(indices[:self._total_num_samples - len(indices)])
         indices = indices[self._rank_id:self._total_num_samples:self._rank_size]

         batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._batchs_per_rank)]
@@ -463,10 +466,12 @@ class DistributedSamplerOfTrain:
         """
         return self._batchs_per_rank

+
 class SequenceSampler:
     """
     A sequence sampler for dataset.
     """
+
     def __init__(self, eval_batch_size, num_users):
         self._eval_users_per_batch = int(
             eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
@@ -491,10 +496,12 @@ class SequenceSampler:
         """
         return self._eval_batches_per_epoch

+
 class DistributedSamplerOfEval:
     """
     A distributed sampler for eval dataset.
     """
+
     def __init__(self, eval_batch_size, num_users, rank_id, rank_size):
         self._eval_users_per_batch = int(
             eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
@@ -507,8 +514,8 @@ class DistributedSamplerOfEval:
         self._eval_batch_size = eval_batch_size

         self._batchs_per_rank = int(math.ceil(self._eval_batches_per_epoch / rank_size))
-        #self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size))
-        #self._total_num_samples = self._samples_per_rank * self._rank_size
+        # self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size))
+        # self._total_num_samples = self._samples_per_rank * self._rank_size

     def __iter__(self):
         indices = [(x * self._eval_users_per_batch, (x + self._rank_id + 1) * self._eval_users_per_batch)
@@ -525,6 +532,7 @@ class DistributedSamplerOfEval:
     def __len__(self):
         return self._batchs_per_rank

+
 def parse_eval_batch_size(eval_batch_size):
     """
     Parse eval batch size.