Merge pull request !2306 from luoyang/pylinttags/v0.6.0-beta
| @@ -20,6 +20,7 @@ import os | |||
| import pickle as pkl | |||
| import numpy as np | |||
| import scipy.sparse as sp | |||
| from mindspore import log as logger | |||
| # parse args from command line parameter 'graph_api_args' | |||
| # args delimiter is ':' | |||
| @@ -58,7 +59,7 @@ def yield_nodes(task_id=0): | |||
| Yields: | |||
| data (dict): data row which is dict. | |||
| """ | |||
| print("Node task is {}".format(task_id)) | |||
| logger.info("Node task is {}".format(task_id)) | |||
| names = ['x', 'y', 'tx', 'ty', 'allx', 'ally'] | |||
| objects = [] | |||
| for name in names: | |||
| @@ -98,7 +99,7 @@ def yield_nodes(task_id=0): | |||
| line_count += 1 | |||
| node_ids.append(i) | |||
| yield node | |||
| print('Processed {} lines for nodes.'.format(line_count)) | |||
| logger.info('Processed {} lines for nodes.'.format(line_count)) | |||
| def yield_edges(task_id=0): | |||
| @@ -108,21 +109,21 @@ def yield_edges(task_id=0): | |||
| Yields: | |||
| data (dict): data row which is dict. | |||
| """ | |||
| print("Edge task is {}".format(task_id)) | |||
| logger.info("Edge task is {}".format(task_id)) | |||
| with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f: | |||
| graph = pkl.load(f, encoding='latin1') | |||
| line_count = 0 | |||
| for i in graph: | |||
| for dst_id in graph[i]: | |||
| if not i in node_ids: | |||
| print('Source node {} does not exist.'.format(i)) | |||
| logger.info('Source node {} does not exist.'.format(i)) | |||
| continue | |||
| if not dst_id in node_ids: | |||
| print('Destination node {} does not exist.'.format( | |||
| logger.info('Destination node {} does not exist.'.format( | |||
| dst_id)) | |||
| continue | |||
| edge = {'id': line_count, | |||
| 'src_id': i, 'dst_id': dst_id, 'type': 0} | |||
| line_count += 1 | |||
| yield edge | |||
| print('Processed {} lines for edges.'.format(line_count)) | |||
| logger.info('Processed {} lines for edges.'.format(line_count)) | |||
| @@ -16,6 +16,7 @@ | |||
| Graph data convert tool for MindRecord. | |||
| """ | |||
| import numpy as np | |||
| from mindspore import log as logger | |||
| __all__ = ['GraphMapSchema'] | |||
| @@ -41,6 +42,7 @@ class GraphMapSchema: | |||
| "edge_feature_index": {"type": "int32", "shape": [-1]} | |||
| } | |||
| @property | |||
| def get_schema(self): | |||
| """ | |||
| Get schema | |||
| @@ -52,6 +54,7 @@ class GraphMapSchema: | |||
| Set node features profile | |||
| """ | |||
| if num_features != len(features_data_type) or num_features != len(features_shape): | |||
| logger.info("Node feature profile is not match.") | |||
| raise ValueError("Node feature profile is not match.") | |||
| self.num_node_features = num_features | |||
| @@ -66,6 +69,7 @@ class GraphMapSchema: | |||
| Set edge features profile | |||
| """ | |||
| if num_features != len(features_data_type) or num_features != len(features_shape): | |||
| logger.info("Edge feature profile is not match.") | |||
| raise ValueError("Edge feature profile is not match.") | |||
| self.num_edge_features = num_features | |||
| @@ -83,6 +87,10 @@ class GraphMapSchema: | |||
| Returns: | |||
| graph data with union schema | |||
| """ | |||
| if node is None: | |||
| logger.info("node cannot be None.") | |||
| raise ValueError("node cannot be None.") | |||
| node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"], | |||
| "node_feature_index": []} | |||
| for i in range(self.num_node_features): | |||
| @@ -117,6 +125,10 @@ class GraphMapSchema: | |||
| Returns: | |||
| graph data with union schema | |||
| """ | |||
| if edge is None: | |||
| logger.info("edge cannot be None.") | |||
| raise ValueError("edge cannot be None.") | |||
| edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e', | |||
| "type": edge["type"], "edge_feature_index": []} | |||
| @@ -164,7 +164,7 @@ if __name__ == "__main__": | |||
| num_features, feature_data_types, feature_shapes = mr_api.edge_profile | |||
| graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes) | |||
| graph_schema = graph_map_schema.get_schema() | |||
| graph_schema = graph_map_schema.get_schema | |||
| # init writer | |||
| writer = init_writer(graph_schema) | |||
| @@ -983,7 +983,9 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz | |||
| continue; | |||
| } | |||
| } | |||
| memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, count * type_size); | |||
| int return_code = memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, | |||
| count * type_size); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed in SliceNumeric"); | |||
| out_index += count; | |||
| if (i < indices.size() - 1) { | |||
| src_start = HandleNeg(indices[i + 1], dim_length); // next index | |||
| @@ -101,6 +101,9 @@ class BucketBatchByLengthOp : public PipelineOp { | |||
| std::vector<int32_t> bucket_batch_sizes, py::function element_length_function, PadInfo pad_info, | |||
| bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size); | |||
| // Destructor | |||
| ~BucketBatchByLengthOp() = default; | |||
| // Might need to batch remaining buckets after receiving eoe, so override this method. | |||
| // @param int32_t workerId | |||
| // @return Status - The error code returned | |||
| @@ -36,6 +36,7 @@ GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers) | |||
| : mr_path_(mr_filepath), | |||
| num_workers_(num_workers), | |||
| row_id_(0), | |||
| shard_reader_(nullptr), | |||
| keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} | |||
| Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, | |||
| @@ -37,7 +37,7 @@ namespace dataset { | |||
| // Driver method for TreePass | |||
| Status TreePass::Run(ExecutionTree *tree, bool *modified) { | |||
| if (!tree || !modified) { | |||
| if (tree == nullptr || modified == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass"); | |||
| } | |||
| return this->RunOnTree(tree, modified); | |||
| @@ -45,7 +45,7 @@ Status TreePass::Run(ExecutionTree *tree, bool *modified) { | |||
| // Driver method for NodePass | |||
| Status NodePass::Run(ExecutionTree *tree, bool *modified) { | |||
| if (!tree || !modified) { | |||
| if (tree == nullptr || modified == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass"); | |||
| } | |||
| std::shared_ptr<DatasetOp> root = tree->root(); | |||
| @@ -44,7 +44,7 @@ class ConnectorSize : public Sampling { | |||
| public: | |||
| explicit ConnectorSize(ExecutionTree *tree) : tree_(tree) {} | |||
| ~ConnectorSize() = default; | |||
| ~ConnectorSize() override = default; | |||
| // Driver function for connector size sampling. | |||
| // This function samples the connector size of every nodes within the ExecutionTree | |||
| @@ -26,6 +26,7 @@ Monitor::Monitor(ExecutionTree *tree) : tree_(tree) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| sampling_interval_ = cfg->monitor_sampling_interval(); | |||
| max_samples_ = 0; | |||
| cur_row_ = 0; | |||
| } | |||
| Status Monitor::operator()() { | |||
| @@ -34,6 +34,8 @@ class Slice { | |||
| Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {} | |||
| explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {} | |||
| ~Slice() = default; | |||
| std::vector<dsize_t> Indices(dsize_t length) { | |||
| std::vector<dsize_t> indices; | |||
| dsize_t index = std::min(Tensor::HandleNeg(start_, length), length); | |||
| @@ -29,8 +29,8 @@ Status RandomHorizontalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow | |||
| BOUNDING_BOX_CHECK(input); | |||
| if (distribution_(rnd_)) { | |||
| // To test bounding boxes algorithm, create random bboxes from image dims | |||
| size_t num_of_boxes = input[1]->shape()[0]; // set to give number of bboxes | |||
| float img_center = (input[0]->shape()[1] / 2); // get the center of the image | |||
| size_t num_of_boxes = input[1]->shape()[0]; // set to give number of bboxes | |||
| float img_center = (input[0]->shape()[1] / 2.); // get the center of the image | |||
| for (int i = 0; i < num_of_boxes; i++) { | |||
| uint32_t b_w = 0; // bounding box width | |||
| @@ -49,6 +49,7 @@ BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, Normal | |||
| preserve_unused_token_(preserve_unused_token), | |||
| case_fold_(std::make_unique<CaseFoldOp>()), | |||
| nfd_normalize_(std::make_unique<NormalizeUTF8Op>(NormalizeForm::kNfd)), | |||
| normalization_form_(normalization_form), | |||
| common_normalize_(std::make_unique<NormalizeUTF8Op>(normalization_form)), | |||
| replace_accent_chars_(std::make_unique<RegexReplaceOp>("\\p{Mn}", "")), | |||
| replace_control_chars_(std::make_unique<RegexReplaceOp>("\\p{Cc}|\\p{Cf}", " ")) { | |||
| @@ -35,9 +35,9 @@ class BasicTokenizerOp : public TensorOp { | |||
| static const bool kDefKeepWhitespace; | |||
| static const NormalizeForm kDefNormalizationForm; | |||
| static const bool kDefPreserveUnusedToken; | |||
| BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace, | |||
| NormalizeForm normalization_form = kDefNormalizationForm, | |||
| bool preserve_unused_token = kDefPreserveUnusedToken); | |||
| explicit BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace, | |||
| NormalizeForm normalization_form = kDefNormalizationForm, | |||
| bool preserve_unused_token = kDefPreserveUnusedToken); | |||
| ~BasicTokenizerOp() override = default; | |||
| @@ -28,14 +28,14 @@ namespace mindspore { | |||
| namespace dataset { | |||
| class BertTokenizerOp : public TensorOp { | |||
| public: | |||
| BertTokenizerOp(const std::shared_ptr<Vocab> &vocab, | |||
| const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator, | |||
| const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken, | |||
| const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken, | |||
| bool lower_case = BasicTokenizerOp::kDefLowerCase, | |||
| bool keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace, | |||
| NormalizeForm normalization_form = BasicTokenizerOp::kDefNormalizationForm, | |||
| bool preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken) | |||
| explicit BertTokenizerOp(const std::shared_ptr<Vocab> &vocab, | |||
| const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator, | |||
| const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken, | |||
| const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken, | |||
| bool lower_case = BasicTokenizerOp::kDefLowerCase, | |||
| bool keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace, | |||
| NormalizeForm normalization_form = BasicTokenizerOp::kDefNormalizationForm, | |||
| bool preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken) | |||
| : wordpiece_tokenizer_(vocab, suffix_indicator, max_bytes_per_token, unknown_token), | |||
| basic_tokenizer_(lower_case, keep_whitespace, normalization_form, preserve_unused_token) {} | |||
| @@ -48,7 +48,7 @@ class AutoIndexObj : public BPlusTree<int64_t, T, A> { | |||
| // @return | |||
| Status insert(const value_type &val, key_type *key = nullptr) { | |||
| key_type my_inx = inx_.fetch_add(1); | |||
| if (key) { | |||
| if (key != nullptr) { | |||
| *key = my_inx; | |||
| } | |||
| return my_tree::DoInsert(my_inx, val); | |||
| @@ -323,7 +323,7 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) | |||
| } | |||
| vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) { | |||
| uint64_t i_size = kUnsignedOne << int_type; | |||
| uint64_t i_size = kUnsignedOne << static_cast<uint8_t>(int_type); | |||
| // Get number of elements | |||
| uint64_t src_n_int = src_bytes.size() / i_size; | |||
| // Calculate bitmap size (bytes) | |||
| @@ -344,7 +344,7 @@ vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const | |||
| // Initialize destination data type | |||
| IntegerType dst_int_type = kInt8Type; | |||
| // Shift to next int position | |||
| uint64_t pos = i * (kUnsignedOne << int_type); | |||
| uint64_t pos = i * (kUnsignedOne << static_cast<uint8_t>(int_type)); | |||
| // Narrow down this int | |||
| int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type); | |||
| @@ -61,7 +61,7 @@ class Shuffle(str, Enum): | |||
| @check_zip | |||
| def zip(datasets): | |||
| """ | |||
| Zips the datasets in the input tuple of datasets. | |||
| Zip the datasets in the input tuple of datasets. | |||
| Args: | |||
| datasets (tuple of class Dataset): A tuple of datasets to be zipped together. | |||
| @@ -152,7 +152,7 @@ class Dataset: | |||
| def get_args(self): | |||
| """ | |||
| Returns attributes (member variables) related to the current class. | |||
| Return attributes (member variables) related to the current class. | |||
| Must include all arguments passed to the __init__() of the current class, excluding 'input_dataset'. | |||
| @@ -239,7 +239,7 @@ class Dataset: | |||
| def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, | |||
| input_columns=None, pad_info=None): | |||
| """ | |||
| Combines batch_size number of consecutive rows into batches. | |||
| Combine batch_size number of consecutive rows into batches. | |||
| For any child node, a batch is treated as a single row. | |||
| For any column, all the elements within that column must have the same shape. | |||
| @@ -340,7 +340,7 @@ class Dataset: | |||
| def flat_map(self, func): | |||
| """ | |||
| Maps `func` to each row in dataset and flatten the result. | |||
| Map `func` to each row in dataset and flatten the result. | |||
| The specified `func` is a function that must take one 'Ndarray' as input | |||
| and return a 'Dataset'. | |||
| @@ -370,6 +370,7 @@ class Dataset: | |||
| """ | |||
| dataset = None | |||
| if not hasattr(func, '__call__'): | |||
| logger.error("func must be a function.") | |||
| raise TypeError("func must be a function.") | |||
| for row_data in self: | |||
| @@ -379,6 +380,7 @@ class Dataset: | |||
| dataset += func(row_data) | |||
| if not isinstance(dataset, Dataset): | |||
| logger.error("flat_map must return a Dataset object.") | |||
| raise TypeError("flat_map must return a Dataset object.") | |||
| return dataset | |||
| @@ -386,7 +388,7 @@ class Dataset: | |||
| def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None, | |||
| num_parallel_workers=None, python_multiprocessing=False): | |||
| """ | |||
| Applies each operation in operations to this dataset. | |||
| Apply each operation in operations to this dataset. | |||
| The order of operations is determined by the position of each operation in operations. | |||
| operations[0] will be applied first, then operations[1], then operations[2], etc. | |||
| @@ -570,7 +572,7 @@ class Dataset: | |||
| @check_repeat | |||
| def repeat(self, count=None): | |||
| """ | |||
| Repeats this dataset count times. Repeat indefinitely if the count is None or -1. | |||
| Repeat this dataset count times. Repeat indefinitely if the count is None or -1. | |||
| Note: | |||
| The order of using repeat and batch reflects the number of batches. Recommend that | |||
| @@ -662,13 +664,16 @@ class Dataset: | |||
| dataset_size = self.get_dataset_size() | |||
| if dataset_size is None or dataset_size <= 0: | |||
| raise RuntimeError("dataset size unknown, unable to split.") | |||
| raise RuntimeError("dataset_size is unknown, unable to split.") | |||
| if not isinstance(sizes, list): | |||
| raise RuntimeError("sizes should be a list.") | |||
| all_int = all(isinstance(item, int) for item in sizes) | |||
| if all_int: | |||
| sizes_sum = sum(sizes) | |||
| if sizes_sum != dataset_size: | |||
| raise RuntimeError("sum of split sizes {} is not equal to dataset size {}." | |||
| raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}." | |||
| .format(sizes_sum, dataset_size)) | |||
| return sizes | |||
| @@ -676,7 +681,7 @@ class Dataset: | |||
| for item in sizes: | |||
| absolute_size = int(round(item * dataset_size)) | |||
| if absolute_size == 0: | |||
| raise RuntimeError("split percentage {} is too small.".format(item)) | |||
| raise RuntimeError("Split percentage {} is too small.".format(item)) | |||
| absolute_sizes.append(absolute_size) | |||
| absolute_sizes_sum = sum(absolute_sizes) | |||
| @@ -694,7 +699,7 @@ class Dataset: | |||
| break | |||
| if sum(absolute_sizes) != dataset_size: | |||
| raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}." | |||
| raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}." | |||
| .format(absolute_sizes_sum, dataset_size)) | |||
| return absolute_sizes | |||
| @@ -702,7 +707,7 @@ class Dataset: | |||
| @check_split | |||
| def split(self, sizes, randomize=True): | |||
| """ | |||
| Splits the dataset into smaller, non-overlapping datasets. | |||
| Split the dataset into smaller, non-overlapping datasets. | |||
| This is a general purpose split function which can be called from any operator in the pipeline. | |||
| There is another, optimized split function, which will be called automatically if ds.split is | |||
| @@ -759,10 +764,10 @@ class Dataset: | |||
| >>> train, test = data.split([0.9, 0.1]) | |||
| """ | |||
| if self.is_shuffled(): | |||
| logger.warning("dataset is shuffled before split.") | |||
| logger.warning("Dataset is shuffled before split.") | |||
| if self.is_sharded(): | |||
| raise RuntimeError("dataset should not be sharded before split.") | |||
| raise RuntimeError("Dataset should not be sharded before split.") | |||
| absolute_sizes = self._get_absolute_split_sizes(sizes) | |||
| splits = [] | |||
| @@ -788,7 +793,7 @@ class Dataset: | |||
| @check_zip_dataset | |||
| def zip(self, datasets): | |||
| """ | |||
| Zips the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name. | |||
| Zip the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name. | |||
| Args: | |||
| datasets (tuple or class Dataset): A tuple of datasets or a single class Dataset | |||
| @@ -845,7 +850,7 @@ class Dataset: | |||
| @check_rename | |||
| def rename(self, input_columns, output_columns): | |||
| """ | |||
| Renames the columns in input datasets. | |||
| Rename the columns in input datasets. | |||
| Args: | |||
| input_columns (list[str]): list of names of the input columns. | |||
| @@ -871,7 +876,7 @@ class Dataset: | |||
| @check_project | |||
| def project(self, columns): | |||
| """ | |||
| Projects certain columns in input datasets. | |||
| Project certain columns in input datasets. | |||
| The specified columns will be selected from the dataset and passed down | |||
| the pipeline in the order specified. The other columns are discarded. | |||
| @@ -936,7 +941,7 @@ class Dataset: | |||
| def device_que(self, prefetch_size=None): | |||
| """ | |||
| Returns a transferredDataset that transfer data through device. | |||
| Return a transferredDataset that transfer data through device. | |||
| Args: | |||
| prefetch_size (int, optional): prefetch number of records ahead of the | |||
| @@ -953,7 +958,7 @@ class Dataset: | |||
| def to_device(self, num_batch=None): | |||
| """ | |||
| Transfers data through CPU, GPU or Ascend devices. | |||
| Transfer data through CPU, GPU or Ascend devices. | |||
| Args: | |||
| num_batch (int, optional): limit the number of batch to be sent to device (default=None). | |||
| @@ -988,7 +993,7 @@ class Dataset: | |||
| raise TypeError("Please set device_type in context") | |||
| if device_type not in ('Ascend', 'GPU', 'CPU'): | |||
| raise ValueError("only support CPU, Ascend, GPU") | |||
| raise ValueError("Only support CPU, Ascend, GPU") | |||
| if num_batch is None or num_batch == 0: | |||
| raise ValueError("num_batch is None or 0.") | |||
| @@ -1089,7 +1094,7 @@ class Dataset: | |||
| def _get_pipeline_info(self): | |||
| """ | |||
| Gets pipeline information. | |||
| Get pipeline information. | |||
| """ | |||
| device_iter = TupleIterator(self) | |||
| self._output_shapes = device_iter.get_output_shapes() | |||
| @@ -1344,7 +1349,7 @@ class MappableDataset(SourceDataset): | |||
| @check_split | |||
| def split(self, sizes, randomize=True): | |||
| """ | |||
| Splits the dataset into smaller, non-overlapping datasets. | |||
| Split the dataset into smaller, non-overlapping datasets. | |||
| There is the optimized split function, which will be called automatically when the dataset | |||
| that calls this function is a MappableDataset. | |||
| @@ -1411,10 +1416,10 @@ class MappableDataset(SourceDataset): | |||
| >>> train.use_sampler(train_sampler) | |||
| """ | |||
| if self.is_shuffled(): | |||
| logger.warning("dataset is shuffled before split.") | |||
| logger.warning("Dataset is shuffled before split.") | |||
| if self.is_sharded(): | |||
| raise RuntimeError("dataset should not be sharded before split.") | |||
| raise RuntimeError("Dataset should not be sharded before split.") | |||
| absolute_sizes = self._get_absolute_split_sizes(sizes) | |||
| splits = [] | |||
| @@ -1633,7 +1638,7 @@ class BlockReleasePair: | |||
| def __init__(self, init_release_rows, callback=None): | |||
| if isinstance(init_release_rows, int) and init_release_rows <= 0: | |||
| raise ValueError("release_rows need to be greater than 0.") | |||
| raise ValueError("release_rows need to be greater than 0.") | |||
| self.row_count = -init_release_rows | |||
| self.cv = threading.Condition() | |||
| self.callback = callback | |||
| @@ -2699,10 +2704,10 @@ class MindDataset(MappableDataset): | |||
| self.shard_id = shard_id | |||
| if block_reader is True and num_shards is not None: | |||
| raise ValueError("block reader not allowed true when use partitions") | |||
| raise ValueError("block_reader not allowed true when use partitions") | |||
| if block_reader is True and shuffle is True: | |||
| raise ValueError("block reader not allowed true when use shuffle") | |||
| raise ValueError("block_reader not allowed true when use shuffle") | |||
| if block_reader is True: | |||
| logger.warning("WARN: global shuffle is not used.") | |||
| @@ -2711,14 +2716,14 @@ class MindDataset(MappableDataset): | |||
| if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler, | |||
| samplers.DistributedSampler, samplers.RandomSampler, | |||
| samplers.SequentialSampler)) is False: | |||
| raise ValueError("the sampler is not supported yet.") | |||
| raise ValueError("The sampler is not supported yet.") | |||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | |||
| self.num_samples = num_samples | |||
| # sampler exclusive | |||
| if block_reader is True and sampler is not None: | |||
| raise ValueError("block reader not allowed true when use sampler") | |||
| raise ValueError("block_reader not allowed true when use sampler") | |||
| if num_padded is None: | |||
| num_padded = 0 | |||
| @@ -2770,7 +2775,7 @@ class MindDataset(MappableDataset): | |||
| if value >= 0: | |||
| self._dataset_size = value | |||
| else: | |||
| raise ValueError('set dataset_size with negative value {}'.format(value)) | |||
| raise ValueError('Set dataset_size with negative value {}'.format(value)) | |||
| def is_shuffled(self): | |||
| if self.shuffle_option is None: | |||
| @@ -2872,7 +2877,7 @@ def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker): | |||
| def _fetch_py_sampler_indices(sampler, num_samples): | |||
| """ | |||
| Indices fetcher for python sampler. | |||
| Indice fetcher for python sampler. | |||
| """ | |||
| if num_samples is not None: | |||
| sampler_iter = iter(sampler) | |||
| @@ -3163,7 +3168,7 @@ class GeneratorDataset(MappableDataset): | |||
| if value >= 0: | |||
| self._dataset_size = value | |||
| else: | |||
| raise ValueError('set dataset_size with negative value {}'.format(value)) | |||
| raise ValueError('Set dataset_size with negative value {}'.format(value)) | |||
| def __deepcopy__(self, memodict): | |||
| if id(self) in memodict: | |||
| @@ -3313,7 +3318,7 @@ class TFRecordDataset(SourceDataset): | |||
| if value >= 0: | |||
| self._dataset_size = value | |||
| else: | |||
| raise ValueError('set dataset_size with negative value {}'.format(value)) | |||
| raise ValueError('Set dataset_size with negative value {}'.format(value)) | |||
| def is_shuffled(self): | |||
| return self.shuffle_files | |||
| @@ -4382,7 +4387,9 @@ class CelebADataset(MappableDataset): | |||
| try: | |||
| with open(attr_file, 'r') as f: | |||
| num_rows = int(f.readline()) | |||
| except Exception: | |||
| except FileNotFoundError: | |||
| raise RuntimeError("attr_file not found.") | |||
| except BaseException: | |||
| raise RuntimeError("Get dataset size failed from attribution file.") | |||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||
| if self.num_samples is not None: | |||
| @@ -319,7 +319,7 @@ class PKSampler(BuiltinSampler): | |||
| raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val)) | |||
| if num_class is not None: | |||
| raise NotImplementedError | |||
| raise NotImplementedError("Not support specify num_class") | |||
| if not isinstance(shuffle, bool): | |||
| raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle)) | |||
| @@ -551,8 +551,8 @@ class WeightedRandomSampler(BuiltinSampler): | |||
| Args: | |||
| weights (list[float]): A sequence of weights, not necessarily summing up to 1. | |||
| num_samples (int): Number of elements to sample (default=None, all elements). | |||
| replacement (bool, optional): If True, put the sample ID back for the next draw (default=True). | |||
| num_samples (int, optional): Number of elements to sample (default=None, all elements). | |||
| replacement (bool): If True, put the sample ID back for the next draw (default=True). | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -50,7 +50,7 @@ def check_filename(path): | |||
| Exception: when error | |||
| """ | |||
| if not isinstance(path, str): | |||
| raise ValueError("path: {} is not string".format(path)) | |||
| raise TypeError("path: {} is not string".format(path)) | |||
| filename = os.path.basename(path) | |||
| # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`', | |||
| @@ -143,7 +143,7 @@ def check_sampler_shuffle_shard_options(param_dict): | |||
| num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') | |||
| if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)): | |||
| raise ValueError("sampler is not a valid Sampler type.") | |||
| raise TypeError("sampler is not a valid Sampler type.") | |||
| if sampler is not None: | |||
| if shuffle is not None: | |||
| @@ -328,13 +328,13 @@ def check_vocdataset(method): | |||
| if task is None: | |||
| raise ValueError("task is not provided.") | |||
| if not isinstance(task, str): | |||
| raise ValueError("task is not str type.") | |||
| raise TypeError("task is not str type.") | |||
| # check mode; required argument | |||
| mode = param_dict.get('mode') | |||
| if mode is None: | |||
| raise ValueError("mode is not provided.") | |||
| if not isinstance(mode, str): | |||
| raise ValueError("mode is not str type.") | |||
| raise TypeError("mode is not str type.") | |||
| imagesets_file = "" | |||
| if task == "Segmentation": | |||
| @@ -388,7 +388,7 @@ def check_cocodataset(method): | |||
| if task is None: | |||
| raise ValueError("task is not provided.") | |||
| if not isinstance(task, str): | |||
| raise ValueError("task is not str type.") | |||
| raise TypeError("task is not str type.") | |||
| if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}: | |||
| raise ValueError("Invalid task type") | |||
| @@ -556,7 +556,7 @@ def check_generatordataset(method): | |||
| def check_batch_size(batch_size): | |||
| if not (isinstance(batch_size, int) or (callable(batch_size))): | |||
| raise ValueError("batch_size should either be an int or a callable.") | |||
| raise TypeError("batch_size should either be an int or a callable.") | |||
| if callable(batch_size): | |||
| sig = ins.signature(batch_size) | |||
| if len(sig.parameters) != 1: | |||
| @@ -706,6 +706,7 @@ def check_batch(method): | |||
| def check_sync_wait(method): | |||
| """check the input arguments of sync_wait.""" | |||
| @wraps(method) | |||
| def new_method(*args, **kwargs): | |||
| param_dict = make_param_dict(method, args, kwargs) | |||
| @@ -773,7 +774,7 @@ def check_filter(method): | |||
| param_dict = make_param_dict(method, args, kwargs) | |||
| predicate = param_dict.get("predicate") | |||
| if not callable(predicate): | |||
| raise ValueError("Predicate should be a python function or a callable python object.") | |||
| raise TypeError("Predicate should be a python function or a callable python object.") | |||
| nreq_param_int = ['num_parallel_workers'] | |||
| check_param_type(nreq_param_int, param_dict, int) | |||
| @@ -865,7 +866,7 @@ def check_zip_dataset(method): | |||
| raise ValueError("datasets is not provided.") | |||
| if not isinstance(ds, (tuple, datasets.Dataset)): | |||
| raise ValueError("datasets is not tuple or of type Dataset.") | |||
| raise TypeError("datasets is not tuple or of type Dataset.") | |||
| return method(*args, **kwargs) | |||
| @@ -885,7 +886,7 @@ def check_concat(method): | |||
| raise ValueError("datasets is not provided.") | |||
| if not isinstance(ds, (list, datasets.Dataset)): | |||
| raise ValueError("datasets is not list or of type Dataset.") | |||
| raise TypeError("datasets is not list or of type Dataset.") | |||
| return method(*args, **kwargs) | |||
| @@ -964,7 +965,7 @@ def check_add_column(method): | |||
| de_type = param_dict.get("de_type") | |||
| if de_type is not None: | |||
| if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type): | |||
| raise ValueError("Unknown column type.") | |||
| raise TypeError("Unknown column type.") | |||
| else: | |||
| raise TypeError("Expected non-empty string.") | |||
| @@ -10,6 +10,6 @@ wheel >= 0.32.0 | |||
| decorator >= 4.4.0 | |||
| setuptools >= 40.8.0 | |||
| matplotlib >= 3.1.3 # for ut test | |||
| opencv-python >= 4.2.0.32 # for ut test | |||
| opencv-python >= 4.1.2.30 # for ut test | |||
| sklearn >= 0.0 # for st test | |||
| pandas >= 1.0.2 # for ut test | |||
| @@ -42,15 +42,15 @@ def split_with_invalid_inputs(d): | |||
| with pytest.raises(RuntimeError) as info: | |||
| _, _ = d.split([3, 1]) | |||
| assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value) | |||
| assert "Sum of split sizes 4 is not equal to dataset size 5" in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| _, _ = d.split([5, 1]) | |||
| assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value) | |||
| assert "Sum of split sizes 6 is not equal to dataset size 5" in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| _, _ = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25]) | |||
| assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value) | |||
| assert "Sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| _, _ = d.split([-0.5, 0.5]) | |||
| @@ -80,7 +80,7 @@ def test_unmappable_invalid_input(): | |||
| d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0) | |||
| with pytest.raises(RuntimeError) as info: | |||
| _, _ = d.split([4, 1]) | |||
| assert "dataset should not be sharded before split" in str(info.value) | |||
| assert "Dataset should not be sharded before split" in str(info.value) | |||
| def test_unmappable_split(): | |||
| @@ -274,7 +274,7 @@ def test_mappable_invalid_input(): | |||
| d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0) | |||
| with pytest.raises(RuntimeError) as info: | |||
| _, _ = d.split([4, 1]) | |||
| assert "dataset should not be sharded before split" in str(info.value) | |||
| assert "Dataset should not be sharded before split" in str(info.value) | |||
| def test_mappable_split_general(): | |||