From: @tiancixiao Reviewed-by: @liucunwei, @heleiwang Signed-off-by: @liucunwei tags/v1.2.0-rc1
| @@ -102,7 +102,7 @@ def get_seed(): | |||
| Get the seed. | |||
| Returns: | |||
| Int, seed. | |||
| int, seed. | |||
| """ | |||
| return _config.get_seed() | |||
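For reference, a minimal sketch of how this getter pairs with its setter (standard mindspore.dataset import assumed; the seed value is arbitrary):

>>> import mindspore.dataset as ds
>>> # get_seed() returns whatever was last set, or a random default
>>> ds.config.set_seed(1005)
>>> assert ds.config.get_seed() == 1005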
| @@ -131,7 +131,7 @@ def get_prefetch_size(): | |||
| Get the prefetch size in number of rows. | |||
| Returns: | |||
| Size, total number of rows to be prefetched. | |||
| int, total number of rows to be prefetched. | |||
| """ | |||
| return _config.get_op_connector_size() | |||
| @@ -162,7 +162,7 @@ def get_num_parallel_workers(): | |||
| This is the DEFAULT num_parallel_workers value used for each op; it is not related to the AutoNumWorker feature. | |||
| Returns: | |||
| Int, number of parallel workers to be used as a default for each operation | |||
| int, number of parallel workers to be used as a default for each operation. | |||
| """ | |||
| return _config.get_num_parallel_workers() | |||
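Likewise for the per-op worker default, a short sketch (the worker count is illustrative):

>>> import mindspore.dataset as ds
>>> ds.config.set_num_parallel_workers(8)
>>> workers = ds.config.get_num_parallel_workers()  # 8, used by ops that omit num_parallel_workers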
| @@ -193,7 +193,7 @@ def get_numa_enable(): | |||
| This is the DEFAULT numa enabled value used for all processes. | |||
| Returns: | |||
| boolean, the default state of numa enabled | |||
| bool, the default state of numa enabled. | |||
| """ | |||
| return _config.get_numa_enable() | |||
| @@ -222,7 +222,7 @@ def get_monitor_sampling_interval(): | |||
| Get the default interval of performance monitor sampling. | |||
| Returns: | |||
| Int, interval (in milliseconds) for performance monitor sampling. | |||
| int, interval (in milliseconds) for performance monitor sampling. | |||
| """ | |||
| return _config.get_monitor_sampling_interval() | |||
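A sketch of the sampling-interval pairing (the interval value is illustrative):

>>> import mindspore.dataset as ds
>>> ds.config.set_monitor_sampling_interval(100)  # in milliseconds
>>> interval = ds.config.get_monitor_sampling_interval()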
| @@ -280,7 +280,8 @@ def get_auto_num_workers(): | |||
| Get the setting (turned on or off) of the automatic number of workers feature. | |||
| Returns: | |||
| Bool, whether auto num worker feature is turned on | |||
| bool, whether auto num worker feature is turned on. | |||
| Examples: | |||
| >>> num_workers = ds.config.get_auto_num_workers() | |||
| """ | |||
| @@ -313,7 +314,7 @@ def get_callback_timeout(): | |||
| In case of a deadlock, the wait function will exit after the timeout period. | |||
| Returns: | |||
| Int, the duration in seconds | |||
| int, the duration in seconds. | |||
| """ | |||
| return _config.get_callback_timeout() | |||
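Since this timeout guards the sync_wait/sync_update handshake shown further down, a quick sketch (the value is illustrative):

>>> import mindspore.dataset as ds
>>> ds.config.set_callback_timeout(60)  # in seconds
>>> timeout = ds.config.get_callback_timeout()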
| @@ -323,7 +324,7 @@ def __str__(): | |||
| String representation of the configurations. | |||
| Returns: | |||
| Str, configurations. | |||
| str, configurations. | |||
| """ | |||
| return str(_config) | |||
| @@ -80,7 +80,7 @@ def zip(datasets): | |||
| The number of datasets must be more than 1. | |||
| Returns: | |||
| Dataset, ZipDataset. | |||
| ZipDataset, dataset zipped. | |||
| Raises: | |||
| ValueError: If the number of datasets is 1. | |||
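To make the new ZipDataset wording concrete, a minimal sketch (column names are hypothetical; zipped datasets must not share column names):

>>> import numpy as np
>>> import mindspore.dataset as ds
>>> data1 = ds.NumpySlicesDataset(np.array([1, 2, 3]), column_names=["col1"], shuffle=False)
>>> data2 = ds.NumpySlicesDataset(np.array([4, 5, 6]), column_names=["col2"], shuffle=False)
>>> zipped = ds.zip((data1, data2))  # a ZipDataset with columns col1 and col2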
| @@ -149,8 +149,8 @@ class Dataset: | |||
| Internal method to create an IR tree. | |||
| Returns: | |||
| ir_tree, The onject of the IR tree. | |||
| dataset, the root dataset of the IR tree. | |||
| DatasetNode, the root node of the IR tree. | |||
| Dataset, the root dataset of the IR tree. | |||
| """ | |||
| parent = self.parent | |||
| self.parent = [] | |||
| @@ -165,7 +165,7 @@ class Dataset: | |||
| Internal method to parse the API tree into an IR tree. | |||
| Returns: | |||
| DatasetNode, The root of the IR tree. | |||
| DatasetNode, the root node of the IR tree. | |||
| """ | |||
| if len(self.parent) > 1: | |||
| raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)") | |||
| @@ -197,7 +197,7 @@ class Dataset: | |||
| Args: | |||
| Returns: | |||
| Python dictionary. | |||
| dict, attributes related to the current class. | |||
| """ | |||
| args = dict() | |||
| args["num_parallel_workers"] = self.num_parallel_workers | |||
| @@ -211,7 +211,7 @@ class Dataset: | |||
| filename (str): filename of json file to be saved as | |||
| Returns: | |||
| Str, JSON string of the pipeline. | |||
| str, JSON string of the pipeline. | |||
| """ | |||
| return json.loads(self.parse_tree().to_json(filename)) | |||
| @@ -258,6 +258,9 @@ class Dataset: | |||
| drop_remainder (bool, optional): If True, will drop the last batch for each | |||
| bucket if it is not a full batch (default=False). | |||
| Returns: | |||
| BucketBatchByLengthDataset, dataset bucketed and batched by length. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> | |||
| @@ -371,6 +374,9 @@ class Dataset: | |||
| num_batch (int): the number of batches without blocking at the start of each epoch. | |||
| callback (function): The callback function that will be invoked when sync_update is called. | |||
| Returns: | |||
| SyncWaitDataset, dataset added a blocking condition. | |||
| Raises: | |||
| RuntimeError: If condition name already exists. | |||
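A condensed sketch of the blocking-condition flow behind the SyncWaitDataset return (the condition name and the update callback are illustrative):

>>> import numpy as np
>>> import mindspore.dataset as ds
>>> state = {"loss": 0}
>>> def update(data):
...     state["loss"] = data["loss"]  # invoked when sync_update is called
>>> dataset = ds.NumpySlicesDataset(np.arange(8), column_names=["input"], shuffle=False)
>>> dataset = dataset.sync_wait(condition_name="policy", num_batch=1, callback=update)
>>> dataset = dataset.batch(4)
>>> for count, _ in enumerate(dataset.create_dict_iterator(output_numpy=True)):
...     dataset.sync_update(condition_name="policy", data={"loss": count})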
| @@ -434,7 +440,7 @@ class Dataset: | |||
| return a 'Dataset'. | |||
| Returns: | |||
| Dataset, applied by the function. | |||
| Dataset, dataset applied by the function. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -650,7 +656,7 @@ class Dataset: | |||
| in parallel (default=None). | |||
| Returns: | |||
| FilterDataset, dataset filter. | |||
| FilterDataset, dataset filtered. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -748,6 +754,9 @@ class Dataset: | |||
| """ | |||
| Internal method called by split to calculate absolute split sizes and to | |||
| do some error checking after calculating absolute split sizes. | |||
| Returns: | |||
| int, absolute split sizes of the dataset. | |||
| """ | |||
| # Call get_dataset_size here and check input here because | |||
| # don't want to call this once in check_split and another time in | |||
| @@ -1015,7 +1024,7 @@ class Dataset: | |||
| is specified and special_first is set to default, special_tokens will be prepended | |||
| Returns: | |||
| Vocab node | |||
| Vocab, vocab built from the dataset. | |||
| Example: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -1074,7 +1083,7 @@ class Dataset: | |||
| params(dict): contains more optional parameters of sentencepiece library | |||
| Returns: | |||
| SentencePieceVocab node | |||
| SentencePieceVocab, vocab built from the dataset. | |||
| Example: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -1115,7 +1124,7 @@ class Dataset: | |||
| return a preprocessed 'Dataset'. | |||
| Returns: | |||
| Dataset, applied by the function. | |||
| Dataset, dataset applied by the function. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -1159,7 +1168,7 @@ class Dataset: | |||
| If device is Ascend, features of data will be transferred one by one. Each | |||
| transmission is limited to 256M of data. | |||
| Return: | |||
| Returns: | |||
| TransferDataset, dataset for transferring. | |||
| """ | |||
| return self.to_device(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue) | |||
| @@ -1287,7 +1296,7 @@ class Dataset: | |||
| use this param to select the conversion method, only take False for better performance (default=True). | |||
| Returns: | |||
| Iterator, list of ndarrays. | |||
| TupleIterator, tuple iterator over the dataset. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -1322,7 +1331,7 @@ class Dataset: | |||
| if output_numpy=False, iterator will output MSTensor (default=False). | |||
| Returns: | |||
| Iterator, dictionary of column name-ndarray pair. | |||
| DictIterator, dictionary iterator over the dataset. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
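A short sketch contrasting the two iterator return types (data values are illustrative):

>>> import numpy as np
>>> import mindspore.dataset as ds
>>> dataset = ds.NumpySlicesDataset(np.arange(4), column_names=["data"], shuffle=False)
>>> for item in dataset.create_tuple_iterator(output_numpy=True):
...     print(item)          # a list of ndarrays, one entry per column
>>> for item in dataset.create_dict_iterator(output_numpy=True):
...     print(item["data"])  # a column-name to ndarray dict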
| @@ -1352,6 +1361,9 @@ class Dataset: | |||
| """ | |||
| Get input index information. | |||
| Returns: | |||
| tuple, tuple of the input index information. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> | |||
| @@ -1409,6 +1421,9 @@ class Dataset: | |||
| def get_col_names(self): | |||
| """ | |||
| Get names of the columns in the dataset | |||
| Returns: | |||
| list, list of column names in the dataset. | |||
| """ | |||
| if self._col_names is None: | |||
| runtime_getter = self._init_tree_getters() | |||
| @@ -1419,8 +1434,8 @@ class Dataset: | |||
| """ | |||
| Get the shapes of output data. | |||
| Return: | |||
| List, list of shapes of each column. | |||
| Returns: | |||
| list, list of shapes of each column. | |||
| """ | |||
| if self.saved_output_shapes is None: | |||
| runtime_getter = self._init_tree_getters() | |||
| @@ -1432,8 +1447,8 @@ class Dataset: | |||
| """ | |||
| Get the types of output data. | |||
| Return: | |||
| List of data types. | |||
| Returns: | |||
| list, list of data types. | |||
| """ | |||
| if self.saved_output_types is None: | |||
| runtime_getter = self._init_tree_getters() | |||
| @@ -1445,8 +1460,8 @@ class Dataset: | |||
| """ | |||
| Get the number of batches in an epoch. | |||
| Return: | |||
| Number, number of batches. | |||
| Returns: | |||
| int, number of batches. | |||
| """ | |||
| if self.dataset_size is None: | |||
| runtime_getter = self._init_size_getter() | |||
| @@ -1457,8 +1472,8 @@ class Dataset: | |||
| """ | |||
| Get the number of classes in a dataset. | |||
| Return: | |||
| Number, number of classes. | |||
| Returns: | |||
| int, number of classes. | |||
| """ | |||
| if self._num_classes is None: | |||
| runtime_getter = self._init_tree_getters() | |||
| @@ -1511,8 +1526,8 @@ class Dataset: | |||
| """ | |||
| Get the size of a batch. | |||
| Return: | |||
| Number, the number of data in a batch. | |||
| Returns: | |||
| int, the number of data in a batch. | |||
| """ | |||
| if self._batch_size is None: | |||
| runtime_getter = self._init_tree_getters() | |||
| @@ -1525,8 +1540,8 @@ class Dataset: | |||
| """ | |||
| Get the replication times in RepeatDataset; otherwise return 1. | |||
| Return: | |||
| Number, the count of repeat. | |||
| Returns: | |||
| int, the count of repeat. | |||
| """ | |||
| if self._repeat_count is None: | |||
| runtime_getter = self._init_tree_getters() | |||
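The introspection getters above compose on one pipeline; a sketch, with expected values as comments (assuming the shapes shown):

>>> import numpy as np
>>> import mindspore.dataset as ds
>>> dataset = ds.NumpySlicesDataset(np.ones((10, 2)), column_names=["data"], shuffle=False)
>>> dataset = dataset.batch(2).repeat(3)
>>> dataset.get_col_names()     # ['data']
>>> dataset.output_shapes()     # [[2, 2]]
>>> dataset.output_types()      # [float64]
>>> dataset.get_dataset_size()  # 15: 5 batches per pass, multiplied by repeat(3)
>>> dataset.get_batch_size()    # 2
>>> dataset.get_repeat_count()  # 3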
| @@ -1540,8 +1555,8 @@ class Dataset: | |||
| Get the class index. | |||
| Returns: | |||
| Dict, A str-to-int mapping from label name to index. | |||
| Dict, A str-to-list<int> mapping from label name to index for Coco ONLY. The second number | |||
| dict, a str-to-int mapping from label name to index. | |||
| dict, a str-to-list<int> mapping from label name to index for Coco ONLY. The second number | |||
| in the list is used to indicate the super category. | |||
| """ | |||
| if self.children: | |||
| @@ -1588,7 +1603,7 @@ class SourceDataset(Dataset): | |||
| patterns (Union[str, list[str]]): String or list of patterns to be searched. | |||
| Returns: | |||
| List, files. | |||
| list, list of files. | |||
| """ | |||
| if not isinstance(patterns, list): | |||
| @@ -1646,9 +1661,6 @@ class MappableDataset(SourceDataset): | |||
| Args: | |||
| new_sampler (Sampler): The sampler to use for the current dataset. | |||
| Returns: | |||
| Dataset, that uses new_sampler. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> | |||
| @@ -1909,8 +1921,9 @@ class BatchDataset(Dataset): | |||
| Args: | |||
| dataset (Dataset): Dataset to be checked. | |||
| Return: | |||
| True or False. | |||
| Returns: | |||
| bool, whether repeat is used before batch. | |||
| """ | |||
| if isinstance(dataset, RepeatDataset): | |||
| return True | |||
| @@ -1995,18 +2008,12 @@ class BatchInfo(cde.CBatchInfo): | |||
| def get_batch_num(self): | |||
| """ | |||
| Return the batch number of the current batch. | |||
| Return: | |||
| Number, number of the current batch. | |||
| """ | |||
| return | |||
| def get_epoch_num(self): | |||
| """ | |||
| Return the epoch number of the current batch. | |||
| Return: | |||
| Number, number of the current epoch. | |||
| """ | |||
| return | |||
| @@ -2055,8 +2062,8 @@ class BlockReleasePair: | |||
| """ | |||
| Function for handling the blocking condition. | |||
| Return: | |||
| True | |||
| Returns: | |||
| bool, True. | |||
| """ | |||
| with self.cv: | |||
| # if disable is true, the condition always evaluates to true | |||
| @@ -2145,8 +2152,9 @@ class SyncWaitDataset(Dataset): | |||
| Args: | |||
| dataset (Dataset): Dataset to be checked. | |||
| Return: | |||
| True or False. | |||
| Returns: | |||
| bool, whether sync_wait is used before batch. | |||
| """ | |||
| if isinstance(dataset, BatchDataset): | |||
| return True | |||
| @@ -2932,6 +2940,9 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, n | |||
| num_shards (int): Number of shard for sharding. | |||
| shard_id (int): Shard ID. | |||
| non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False). | |||
| Returns: | |||
| Sampler, sampler selected based on user input. | |||
| """ | |||
| if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]): | |||
| return None | |||
| @@ -4180,7 +4191,7 @@ class ManifestDataset(MappableDataset): | |||
| Get the class index. | |||
| Returns: | |||
| Dict, A str-to-int mapping from label name to index. | |||
| dict, a str-to-int mapping from label name to index. | |||
| """ | |||
| if self.class_indexing is None: | |||
| if self._class_indexing is None: | |||
| @@ -4579,7 +4590,7 @@ class Schema: | |||
| Args: | |||
| schema_file(str): Path of schema file (default=None). | |||
| Return: | |||
| Returns: | |||
| Schema object, schema info about dataset. | |||
| Raises: | |||
| @@ -4654,7 +4665,7 @@ class Schema: | |||
| Get a JSON string of the schema. | |||
| Returns: | |||
| Str, JSON string of the schema. | |||
| str, JSON string of the schema. | |||
| """ | |||
| return self.cpp_schema.to_json() | |||
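A minimal sketch of building a Schema and serializing it (column names and types are illustrative):

>>> import mindspore.dataset as ds
>>> schema = ds.Schema()
>>> schema.add_column(name="image", de_type="uint8", shape=[-1])
>>> schema.add_column(name="label", de_type="int32")
>>> json_str = schema.to_json()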
| @@ -4840,7 +4851,7 @@ class VOCDataset(MappableDataset): | |||
| Get the class index. | |||
| Returns: | |||
| Dict, A str-to-int mapping from label name to index. | |||
| dict, a str-to-int mapping from label name to index. | |||
| """ | |||
| if self.task != "Detection": | |||
| raise NotImplementedError("Only 'Detection' support get_class_indexing.") | |||
| @@ -5032,7 +5043,7 @@ class CocoDataset(MappableDataset): | |||
| Get the class index. | |||
| Returns: | |||
| Dict, A str-to-list<int> mapping from label name to index | |||
| dict, a str-to-list<int> mapping from label name to index | |||
| """ | |||
| if self.task not in {"Detection", "Panoptic"}: | |||
| raise NotImplementedError("Only 'Detection' and 'Panoptic' support get_class_indexing.") | |||
| @@ -100,7 +100,7 @@ class GraphData: | |||
| node_type (int): Specify the type of node. | |||
| Returns: | |||
| numpy.ndarray: Array of nodes. | |||
| numpy.ndarray, array of nodes. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -124,7 +124,7 @@ class GraphData: | |||
| edge_type (int): Specify the type of edge. | |||
| Returns: | |||
| numpy.ndarray: array of edges. | |||
| numpy.ndarray, array of edges. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -148,7 +148,7 @@ class GraphData: | |||
| edge_list (Union[list, numpy.ndarray]): The given list of edges. | |||
| Returns: | |||
| numpy.ndarray: Array of nodes. | |||
| numpy.ndarray, array of nodes. | |||
| Raises: | |||
| TypeError: If `edge_list` is not list or ndarray. | |||
| @@ -167,7 +167,7 @@ class GraphData: | |||
| neighbor_type (int): Specify the type of neighbor. | |||
| Returns: | |||
| numpy.ndarray: Array of nodes. | |||
| numpy.ndarray, array of neighbors. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -201,7 +201,7 @@ class GraphData: | |||
| neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop. | |||
| Returns: | |||
| numpy.ndarray: Array of nodes. | |||
| numpy.ndarray, array of neighbors. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -231,7 +231,7 @@ class GraphData: | |||
| neg_neighbor_type (int): Specify the type of negative neighbor. | |||
| Returns: | |||
| numpy.ndarray: Array of nodes. | |||
| numpy.ndarray, array of neighbors. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -260,7 +260,7 @@ class GraphData: | |||
| feature_types (Union[list, numpy.ndarray]): The given list of feature types. | |||
| Returns: | |||
| numpy.ndarray: array of features. | |||
| numpy.ndarray, array of features. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -292,7 +292,7 @@ class GraphData: | |||
| feature_types (Union[list, numpy.ndarray]): The given list of feature types. | |||
| Returns: | |||
| numpy.ndarray: array of features. | |||
| numpy.ndarray, array of features. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -320,7 +320,7 @@ class GraphData: | |||
| the feature information of nodes, the number of edges, the type of edges, and the feature information of edges. | |||
| Returns: | |||
| dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num, | |||
| dict, meta information of the graph. The key is node_type, edge_type, node_num, edge_num, | |||
| node_feature_type and edge_feature_type. | |||
| """ | |||
| if self._working_mode == 'server': | |||
| @@ -347,7 +347,7 @@ class GraphData: | |||
| A default value of -1 indicates that no node is given. | |||
| Returns: | |||
| numpy.ndarray: Array of nodes. | |||
| numpy.ndarray, array of nodes. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
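A hypothetical end-to-end sketch of the GraphData getters above (the file path and the node/neighbor/feature type ids are placeholders; real ids come from get_graph_info()):

>>> import mindspore.dataset as ds
>>> g = ds.GraphData("./social_graph.mindrecord", num_parallel_workers=2)
>>> info = g.get_graph_info()             # node/edge types and counts
>>> nodes = g.get_all_nodes(node_type=1)  # ndarray of node ids
>>> neighbors = g.get_all_neighbors(nodes[:10].tolist(), neighbor_type=2)
>>> features = g.get_node_feature(nodes[:10].tolist(), feature_types=[1])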
| @@ -128,6 +128,7 @@ class BuiltinSampler: | |||
| User should not extend this class. | |||
| """ | |||
| def __init__(self, num_samples=None): | |||
| self.child_sampler = None | |||
| self.num_samples = num_samples | |||
| @@ -201,7 +202,7 @@ class BuiltinSampler: | |||
| - None | |||
| Returns: | |||
| int, The number of samples, or None | |||
| int, the number of samples, or None. | |||
| """ | |||
| if self.child_sampler is not None: | |||
| child_samples = self.child_sampler.get_num_samples() | |||
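A sketch of how get_num_samples resolves through a child sampler (the counts are illustrative):

>>> import mindspore.dataset as ds
>>> sampler = ds.SequentialSampler(start_index=0, num_samples=8)
>>> sampler.add_child(ds.RandomSampler(replacement=False, num_samples=4))
>>> sampler.get_num_samples()  # 4: the smaller of the parent and child budgets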
| @@ -1063,8 +1063,9 @@ class LinearTransformation: | |||
| the dot product with the transformation matrix, and reshapes it back to its original shape. | |||
| Args: | |||
| transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), D = C x H x W. | |||
| mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where D = C x H x W. | |||
| transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), where | |||
| :math:`D = C \times H \times W`. | |||
| mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where :math:`D = C \times H \times W`. | |||
| Examples: | |||
| >>> from mindspore.dataset.transforms.py_transforms import Compose | |||
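To ground the :math:`D = C \times H \times W` note, a sketch with identity whitening on 3x2x2 inputs (matrix values are placeholders):

>>> import numpy as np
>>> from mindspore.dataset.transforms.py_transforms import Compose
>>> from mindspore.dataset.vision import py_transforms as py_vision
>>> D = 3 * 2 * 2
>>> transformation_matrix = np.eye(D)  # identity matrix leaves the data unchanged
>>> mean_vector = np.zeros(D)
>>> transform = Compose([py_vision.ToTensor(),
...                      py_vision.LinearTransformation(transformation_matrix, mean_vector)])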
| @@ -23,7 +23,7 @@ from PIL import Image | |||
| import mindspore as ms | |||
| import mindspore.dataset as ds | |||
| from mindspore import log | |||
| from mindspore.dataset.engine.datasets import Dataset | |||
| from mindspore.dataset import Dataset | |||
| from mindspore.nn import Cell, SequentialCell | |||
| from mindspore.ops.operations import ExpandDims | |||
| from mindspore.train._utils import check_value_type | |||
| @@ -30,6 +30,7 @@ from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchema | |||
| __all__ = ['FileWriter'] | |||
| class FileWriter: | |||
| """ | |||
| Class to write user-defined raw data into MindRecord File series. | |||
| @@ -45,6 +46,7 @@ class FileWriter: | |||
| Raises: | |||
| ParamValueError: If `file_name` or `shard_num` is invalid. | |||
| """ | |||
| def __init__(self, file_name, shard_num=1): | |||
| check_filename(file_name) | |||
| self._file_name = file_name | |||
| @@ -84,7 +86,7 @@ class FileWriter: | |||
| file_name (str): String of MindRecord file name. | |||
| Returns: | |||
| Instance of FileWriter. | |||
| FileWriter, file writer for the opened MindRecord file. | |||
| Raises: | |||
| ParamValueError: If file_name is invalid. | |||
| @@ -118,7 +120,7 @@ class FileWriter: | |||
| desc (str, optional): String of schema description (default=None). | |||
| Returns: | |||
| An integer, schema id. | |||
| int, schema id. | |||
| Raises: | |||
| MRMInvalidSchemaError: If schema is invalid. | |||
| @@ -175,17 +177,17 @@ class FileWriter: | |||
| if field not in v: | |||
| error_data_dic[i] = "for schema, {} th data is wrong, " \ | |||
| "there is not '{}' object in the raw data.".format(i, field) | |||
| "there is not '{}' object in the raw data.".format(i, field) | |||
| continue | |||
| field_type = type(v[field]).__name__ | |||
| if field_type not in VALUE_TYPE_MAP: | |||
| error_data_dic[i] = "for schema, {} th data is wrong, " \ | |||
| "data type for '{}' is not matched.".format(i, field) | |||
| "data type for '{}' is not matched.".format(i, field) | |||
| continue | |||
| if schema_content[field]["type"] not in VALUE_TYPE_MAP[field_type]: | |||
| error_data_dic[i] = "for schema, {} th data is wrong, " \ | |||
| "data type for '{}' is not matched.".format(i, field) | |||
| "data type for '{}' is not matched.".format(i, field) | |||
| continue | |||
| if field_type == 'ndarray': | |||
| @@ -206,7 +208,6 @@ class FileWriter: | |||
| def open_and_set_header(self): | |||
| """ | |||
| Open writer and set header. | |||
| """ | |||
| if not self._writer.is_open: | |||
| self._writer.open(self._paths) | |||
| @@ -222,6 +223,9 @@ class FileWriter: | |||
| raw_data (list[dict]): List of raw data. | |||
| parallel_writer (bool, optional): Load data parallel if it equals to True (default=False). | |||
| Returns: | |||
| MSRStatus, SUCCESS or FAILED. | |||
| Raises: | |||
| ParamTypeError: If index field is invalid. | |||
| MRMOpenError: If failed to open MindRecord File. | |||
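A minimal end-to-end sketch of the write path whose statuses are documented above (file name and schema are illustrative):

>>> from mindspore.mindrecord import FileWriter
>>> schema = {"file_name": {"type": "string"},
...           "label": {"type": "int32"},
...           "data": {"type": "bytes"}}
>>> writer = FileWriter(file_name="test.mindrecord", shard_num=1)
>>> writer.add_schema(schema, "example schema")
>>> writer.add_index(["file_name", "label"])
>>> writer.write_raw_data([{"file_name": "001.jpg", "label": 0, "data": b"\x10"}])
>>> writer.commit()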
| @@ -330,7 +334,7 @@ class FileWriter: | |||
| v (dict): Sub dict in schema | |||
| Returns: | |||
| bool, True or False. | |||
| bool, whether the array item is valid. | |||
| str, error message. | |||
| """ | |||
| if v['type'] not in VALID_ARRAY_ATTRIBUTES: | |||
| @@ -355,7 +359,7 @@ class FileWriter: | |||
| content (dict): Dict of raw schema. | |||
| Returns: | |||
| bool, True or False. | |||
| bool, whether the schema is valid. | |||
| str, error message. | |||
| """ | |||
| error = '' | |||
| @@ -23,6 +23,7 @@ from .common.exceptions import ParamValueError, ParamTypeError, MRMDefineCategor | |||
| __all__ = ['MindPage'] | |||
| class MindPage: | |||
| """ | |||
| Class to read MindRecord File series in pagination. | |||
| @@ -36,6 +37,7 @@ class MindPage: | |||
| ParamValueError: If `file_name`, `num_consumer` or columns is invalid. | |||
| MRMInitSegmentError: If failed to initialize ShardSegment. | |||
| """ | |||
| def __init__(self, file_name, num_consumer=4): | |||
| if isinstance(file_name, list): | |||
| for f in file_name: | |||
| @@ -69,7 +71,12 @@ class MindPage: | |||
| return self._candidate_fields | |||
| def get_category_fields(self): | |||
| """Return candidate category fields.""" | |||
| """ | |||
| Return candidate category fields. | |||
| Returns: | |||
| list[str], fields by which data could be grouped. | |||
| """ | |||
| logger.warning("WARN_DEPRECATED: The usage of get_category_fields is deprecated." | |||
| " Please use candidate_fields") | |||
| return self.candidate_fields | |||
| @@ -97,12 +104,22 @@ class MindPage: | |||
| @property | |||
| def category_field(self): | |||
| """Getter function for category fields.""" | |||
| """ | |||
| Getter function for category fields. | |||
| Returns: | |||
| list[str], fields by which data could be grouped. | |||
| """ | |||
| return self._category_field | |||
| @category_field.setter | |||
| def category_field(self, category_field): | |||
| """Setter function for category field""" | |||
| """ | |||
| Setter function for category field. | |||
| Returns: | |||
| MSRStatus, SUCCESS or FAILED. | |||
| """ | |||
| if not category_field or not isinstance(category_field, str): | |||
| raise ParamTypeError('category_fields', 'str') | |||
| if category_field not in self._candidate_fields: | |||
| @@ -132,7 +149,7 @@ class MindPage: | |||
| num_row (int): Number of rows in a page. | |||
| Returns: | |||
| List, list[dict]. | |||
| list[dict], data queried by category id. | |||
| Raises: | |||
| ParamValueError: If any parameter is invalid. | |||
| @@ -158,7 +175,7 @@ class MindPage: | |||
| num_row (int): Number of row in a page. | |||
| Returns: | |||
| str, read at page. | |||
| list[dict], data queried by category name. | |||
| """ | |||
| if not isinstance(category_name, str): | |||
| raise ParamValueError("Category name should be str.") | |||
| @@ -23,6 +23,7 @@ from .common.exceptions import MRMOpenError, MRMOpenForAppendError, MRMInvalidHe | |||
| __all__ = ['ShardWriter'] | |||
| class ShardWriter: | |||
| """ | |||
| Wrapper class which represents the ShardWriter class in the C++ module. | |||
| @@ -192,9 +193,11 @@ class ShardWriter: | |||
| if len(blob_data) == 1: | |||
| values = [v for v in blob_data.values()] | |||
| return bytes(values[0]) | |||
| # convert int to bytes | |||
| def int_to_bytes(x: int) -> bytes: | |||
| return x.to_bytes(8, 'big') | |||
| merged = bytes() | |||
| for field, v in blob_data.items(): | |||
| # convert ndarray to bytes | |||
| @@ -209,7 +212,7 @@ class ShardWriter: | |||
| Flush data to disk. | |||
| Returns: | |||
| Class MSRStatus, SUCCESS or FAILED. | |||
| MSRStatus, SUCCESS or FAILED. | |||
| Raises: | |||
| MRMCommitError: If failed to flush data to disk. | |||
| @@ -33,6 +33,7 @@ except ModuleNotFoundError: | |||
| __all__ = ['Cifar100ToMR'] | |||
| class Cifar100ToMR: | |||
| """ | |||
| A class to transform from cifar100 to MindRecord. | |||
| @@ -44,6 +45,7 @@ class Cifar100ToMR: | |||
| Raises: | |||
| ValueError: If source or destination is invalid. | |||
| """ | |||
| def __init__(self, source, destination): | |||
| check_filename(source) | |||
| self.source = source | |||
| @@ -74,7 +76,7 @@ class Cifar100ToMR: | |||
| fields (list[str]): A list of index field, e.g.["fine_label", "coarse_label"]. | |||
| Returns: | |||
| SUCCESS or FAILED, whether cifar100 is successfully transformed to MindRecord. | |||
| MSRStatus, whether cifar100 is successfully transformed to MindRecord. | |||
| """ | |||
| if fields and not isinstance(fields, list): | |||
| raise ValueError("The parameter fields should be None or list") | |||
| @@ -114,6 +116,7 @@ class Cifar100ToMR: | |||
| raise t.exception | |||
| return t.res | |||
| def _construct_raw_data(images, fine_labels, coarse_labels): | |||
| """ | |||
| Construct raw data from cifar100 data. | |||
| @@ -124,7 +127,7 @@ def _construct_raw_data(images, fine_labels, coarse_labels): | |||
| coarse_labels (list): coarse label list from cifar100. | |||
| Returns: | |||
| SUCCESS/FAILED, whether successfully written into MindRecord. | |||
| list[dict], data dictionary constructed from cifar100. | |||
| """ | |||
| if not cv2: | |||
| raise ModuleNotFoundError("opencv-python module not found, please use pip install it.") | |||
| @@ -141,6 +144,7 @@ def _construct_raw_data(images, fine_labels, coarse_labels): | |||
| raw_data.append(row_data) | |||
| return raw_data | |||
| def _generate_mindrecord(file_name, raw_data, fields, schema_desc): | |||
| """ | |||
| Generate MindRecord file from raw data. | |||
| @@ -153,7 +157,7 @@ def _generate_mindrecord(file_name, raw_data, fields, schema_desc): | |||
| schema_desc (str): String of schema description. | |||
| Returns: | |||
| SUCCESS/FAILED, whether successfully written into MindRecord. | |||
| MSRStatus, whether successfully written into MindRecord. | |||
| """ | |||
| schema = {"id": {"type": "int64"}, "fine_label": {"type": "int64"}, | |||
| "coarse_label": {"type": "int64"}, "data": {"type": "bytes"}} | |||
| @@ -25,6 +25,7 @@ from .cifar10 import Cifar10 | |||
| from ..common.exceptions import PathNotExistsError | |||
| from ..filewriter import FileWriter | |||
| from ..shardutils import check_filename, ExceptionThread, SUCCESS, FAILED | |||
| try: | |||
| cv2 = import_module("cv2") | |||
| except ModuleNotFoundError: | |||
| @@ -32,6 +33,7 @@ except ModuleNotFoundError: | |||
| __all__ = ['Cifar10ToMR'] | |||
| class Cifar10ToMR: | |||
| """ | |||
| A class to transform from cifar10 to MindRecord. | |||
| @@ -43,6 +45,7 @@ class Cifar10ToMR: | |||
| Raises: | |||
| ValueError: If source or destination is invalid. | |||
| """ | |||
| def __init__(self, source, destination): | |||
| check_filename(source) | |||
| self.source = source | |||
| @@ -73,7 +76,7 @@ class Cifar10ToMR: | |||
| fields (list[str], optional): A list of index fields, e.g.["label"] (default=None). | |||
| Returns: | |||
| SUCCESS or FAILED, whether cifar10 is successfully transformed to MindRecord. | |||
| MSRStatus, whether cifar10 is successfully transformed to MindRecord. | |||
| """ | |||
| if fields and not isinstance(fields, list): | |||
| raise ValueError("The parameter fields should be None or list") | |||
| @@ -109,6 +112,7 @@ class Cifar10ToMR: | |||
| raise t.exception | |||
| return t.res | |||
| def _construct_raw_data(images, labels): | |||
| """ | |||
| Construct raw data from cifar10 data. | |||
| @@ -118,7 +122,7 @@ def _construct_raw_data(images, labels): | |||
| labels (list): label list from cifar10. | |||
| Returns: | |||
| SUCCESS/FAILED, whether successfully written into MindRecord. | |||
| list[dict], data dictionary constructed from cifar10. | |||
| """ | |||
| if not cv2: | |||
| raise ModuleNotFoundError("opencv-python module not found, please use pip install it.") | |||
| @@ -133,6 +137,7 @@ def _construct_raw_data(images, labels): | |||
| raw_data.append(row_data) | |||
| return raw_data | |||
| def _generate_mindrecord(file_name, raw_data, fields, schema_desc): | |||
| """ | |||
| Generate MindRecord file from raw data. | |||
| @@ -145,7 +150,7 @@ def _generate_mindrecord(file_name, raw_data, fields, schema_desc): | |||
| schema_desc (str): String of schema description. | |||
| Returns: | |||
| SUCCESS/FAILED, whether successfully written into MindRecord. | |||
| MSRStatus, whether successfully written into MindRecord. | |||
| """ | |||
| schema = {"id": {"type": "int64"}, "label": {"type": "int64"}, | |||
| "data": {"type": "bytes"}} | |||
| @@ -29,6 +29,7 @@ except ModuleNotFoundError: | |||
| __all__ = ['CsvToMR'] | |||
| class CsvToMR: | |||
| """ | |||
| A class to transform from csv to MindRecord. | |||
| @@ -121,7 +122,7 @@ class CsvToMR: | |||
| Executes transformation from csv to MindRecord. | |||
| Returns: | |||
| SUCCESS or FAILED, whether csv is successfully transformed to MindRecord. | |||
| MSRStatus, whether csv is successfully transformed to MindRecord. | |||
| """ | |||
| if not os.path.exists(self.source): | |||
| raise IOError("Csv file {} do not exist.".format(self.source)) | |||
| @@ -47,6 +47,7 @@ class ImageNetToMR: | |||
| Raises: | |||
| ValueError: If `map_file`, `image_dir` or `destination` is invalid. | |||
| """ | |||
| def __init__(self, map_file, image_dir, destination, partition_number=1): | |||
| check_filename(map_file) | |||
| self.map_file = map_file | |||
| @@ -122,7 +123,7 @@ class ImageNetToMR: | |||
| Executes transformation from imagenet to MindRecord. | |||
| Returns: | |||
| SUCCESS or FAILED, whether imagenet is successfully transformed to MindRecord. | |||
| MSRStatus, whether imagenet is successfully transformed to MindRecord. | |||
| """ | |||
| t0_total = time.time() | |||
| @@ -133,10 +134,10 @@ class ImageNetToMR: | |||
| logger.info("transformed MindRecord schema is: {}".format(imagenet_schema_json)) | |||
| # set the header size | |||
| self.writer.set_header_size(1<<24) | |||
| self.writer.set_header_size(1 << 24) | |||
| # set the page size | |||
| self.writer.set_page_size(1<<26) | |||
| self.writer.set_page_size(1 << 26) | |||
| # create the schema | |||
| self.writer.add_schema(imagenet_schema_json, "imagenet_schema") | |||
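A sketch of the imagenet conversion these header and page sizes serve (paths are placeholders):

>>> from mindspore.mindrecord import ImageNetToMR
>>> converter = ImageNetToMR("./labels_map.txt", "./images/", "./imagenet.mindrecord", partition_number=4)
>>> status = converter.transform()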
| @@ -32,6 +32,7 @@ except ModuleNotFoundError: | |||
| __all__ = ['MnistToMR'] | |||
| class MnistToMR: | |||
| """ | |||
| A class to transform from Mnist to MindRecord. | |||
| @@ -125,7 +126,7 @@ class MnistToMR: | |||
| Executes transformation from Mnist train part to MindRecord. | |||
| Returns: | |||
| SUCCESS/FAILED, whether successfully written into MindRecord. | |||
| MSRStatus, whether successfully written into MindRecord. | |||
| """ | |||
| t0_total = time.time() | |||
| @@ -173,7 +174,7 @@ class MnistToMR: | |||
| Executes transformation from Mnist test part to MindRecord. | |||
| Returns: | |||
| SUCCESS or FAILED, whether Mnist is successfully transformed to MindRecord. | |||
| MSRStatus, whether Mnist is successfully transformed to MindRecord. | |||
| """ | |||
| t0_total = time.time() | |||
| @@ -222,7 +223,7 @@ class MnistToMR: | |||
| Executes transformation from Mnist to MindRecord. | |||
| Returns: | |||
| SUCCESS/FAILED, whether successfully written into MindRecord. | |||
| MSRStatus, whether successfully written into MindRecord. | |||
| """ | |||
| if not cv2: | |||
| raise ModuleNotFoundError("opencv-python module not found, please use pip install it.") | |||
| @@ -23,7 +23,6 @@ from mindspore import log as logger | |||
| from ..filewriter import FileWriter | |||
| from ..shardutils import check_filename, ExceptionThread | |||
| __all__ = ['TFRecordToMR'] | |||
| SupportedTensorFlowVersion = '1.13.0-rc1' | |||
| @@ -86,9 +85,10 @@ class TFRecordToMR: | |||
| ValueError: If parameter is invalid. | |||
| Exception: when tensorflow module is not found or version is not correct. | |||
| """ | |||
| def __init__(self, source, destination, feature_dict, bytes_fields=None): | |||
| try: | |||
| self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord | |||
| self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord | |||
| except ModuleNotFoundError: | |||
| self.tf = None | |||
| if not self.tf: | |||
| @@ -265,7 +265,7 @@ class TFRecordToMR: | |||
| Execute transformation from TFRecord to MindRecord. | |||
| Returns: | |||
| SUCCESS or FAILED, whether TFRecord is successfuly transformed to MindRecord. | |||
| MSRStatus, whether TFRecord is successfully transformed to MindRecord. | |||
| """ | |||
| writer = FileWriter(self.destination) | |||
| logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}" | |||
| @@ -22,13 +22,12 @@ import pickle | |||
| import numpy as np | |||
| import pandas as pd | |||
| from mindspore.dataset.engine import GeneratorDataset | |||
| from mindspore.dataset import GeneratorDataset | |||
| import src.constants as rconst | |||
| import src.movielens as movielens | |||
| import src.stat_utils as stat_utils | |||
| DATASET_TO_NUM_USERS_AND_ITEMS = { | |||
| "ml-1m": (6040, 3706), | |||
| "ml-20m": (138493, 26744) | |||
| @@ -205,6 +204,7 @@ class NCFDataset: | |||
| """ | |||
| A dataset for NCF network. | |||
| """ | |||
| def __init__(self, | |||
| pos_users, | |||
| pos_items, | |||
| @@ -407,6 +407,7 @@ class RandomSampler: | |||
| """ | |||
| A random sampler for dataset. | |||
| """ | |||
| def __init__(self, pos_count, num_train_negatives, batch_size): | |||
| self.pos_count = pos_count | |||
| self._num_samples = (1 + num_train_negatives) * self.pos_count | |||
| @@ -433,6 +434,7 @@ class DistributedSamplerOfTrain: | |||
| """ | |||
| A distributed sampler for dataset. | |||
| """ | |||
| def __init__(self, pos_count, num_train_negatives, batch_size, rank_id, rank_size): | |||
| """ | |||
| Distributed sampler of training dataset. | |||
| @@ -443,15 +445,16 @@ class DistributedSamplerOfTrain: | |||
| self._batch_size = batch_size | |||
| self._batchs_per_rank = int(math.ceil(self._num_samples / self._batch_size / rank_size)) | |||
| self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size)) | |||
| self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size)) | |||
| self._total_num_samples = self._samples_per_rank * self._rank_size | |||
| def __iter__(self): | |||
| """ | |||
| Returns the data after each sampling. | |||
| """ | |||
| indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32())) | |||
| indices = indices.tolist() | |||
| indices.extend(indices[:self._total_num_samples-len(indices)]) | |||
| indices.extend(indices[:self._total_num_samples - len(indices)]) | |||
| indices = indices[self._rank_id:self._total_num_samples:self._rank_size] | |||
| batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._batchs_per_rank)] | |||
| @@ -463,10 +466,12 @@ class DistributedSamplerOfTrain: | |||
| """ | |||
| return self._batchs_per_rank | |||
| class SequenceSampler: | |||
| """ | |||
| A sequence sampler for dataset. | |||
| """ | |||
| def __init__(self, eval_batch_size, num_users): | |||
| self._eval_users_per_batch = int( | |||
| eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES)) | |||
| @@ -491,10 +496,12 @@ class SequenceSampler: | |||
| """ | |||
| return self._eval_batches_per_epoch | |||
| class DistributedSamplerOfEval: | |||
| """ | |||
| A distributed sampler for eval dataset. | |||
| """ | |||
| def __init__(self, eval_batch_size, num_users, rank_id, rank_size): | |||
| self._eval_users_per_batch = int( | |||
| eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES)) | |||
| @@ -507,8 +514,8 @@ class DistributedSamplerOfEval: | |||
| self._eval_batch_size = eval_batch_size | |||
| self._batchs_per_rank = int(math.ceil(self._eval_batches_per_epoch / rank_size)) | |||
| #self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size)) | |||
| #self._total_num_samples = self._samples_per_rank * self._rank_size | |||
| # self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size)) | |||
| # self._total_num_samples = self._samples_per_rank * self._rank_size | |||
| def __iter__(self): | |||
| indices = [(x * self._eval_users_per_batch, (x + self._rank_id + 1) * self._eval_users_per_batch) | |||
| @@ -525,6 +532,7 @@ class DistributedSamplerOfEval: | |||
| def __len__(self): | |||
| return self._batchs_per_rank | |||
| def parse_eval_batch_size(eval_batch_size): | |||
| """ | |||
| Parse eval batch size. | |||