Merge pull request !2306 from luoyang/pylint (tags/v0.6.0-beta)

@@ -20,6 +20,7 @@ import os
 import pickle as pkl
 import numpy as np
 import scipy.sparse as sp
+from mindspore import log as logger

 # parse args from command line parameter 'graph_api_args'
 # args delimiter is ':'

@@ -58,7 +59,7 @@ def yield_nodes(task_id=0):
     Yields:
         data (dict): data row which is dict.
     """
-    print("Node task is {}".format(task_id))
+    logger.info("Node task is {}".format(task_id))
     names = ['x', 'y', 'tx', 'ty', 'allx', 'ally']
     objects = []
     for name in names:

@@ -98,7 +99,7 @@ def yield_nodes(task_id=0):
            line_count += 1
            node_ids.append(i)
            yield node
-    print('Processed {} lines for nodes.'.format(line_count))
+    logger.info('Processed {} lines for nodes.'.format(line_count))


 def yield_edges(task_id=0):

@@ -108,21 +109,21 @@ def yield_edges(task_id=0):
     Yields:
         data (dict): data row which is dict.
     """
-    print("Edge task is {}".format(task_id))
+    logger.info("Edge task is {}".format(task_id))
     with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f:
         graph = pkl.load(f, encoding='latin1')
     line_count = 0
     for i in graph:
         for dst_id in graph[i]:
             if not i in node_ids:
-                print('Source node {} does not exist.'.format(i))
+                logger.info('Source node {} does not exist.'.format(i))
                 continue
             if not dst_id in node_ids:
-                print('Destination node {} does not exist.'.format(
-                    dst_id))
+                logger.info('Destination node {} does not exist.'.format(
+                    dst_id))
                 continue
             edge = {'id': line_count,
                     'src_id': i, 'dst_id': dst_id, 'type': 0}
             line_count += 1
             yield edge
-    print('Processed {} lines for edges.'.format(line_count))
+    logger.info('Processed {} lines for edges.'.format(line_count))
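
The hunks above replace bare print() calls with MindSpore's logging helper. As a reference, a minimal sketch of that pattern, with an illustrative function name that is not part of this PR:

    from mindspore import log as logger

    def yield_items(task_id=0):
        # logger.info/warning/error mirror the stdlib logging levels
        logger.info("Task is {}".format(task_id))
        count = 0
        for item in range(3):
            count += 1
            yield item
        logger.info("Processed {} lines.".format(count))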

@@ -16,6 +16,7 @@
 Graph data convert tool for MindRecord.
 """
 import numpy as np
+from mindspore import log as logger

 __all__ = ['GraphMapSchema']

@@ -41,6 +42,7 @@ class GraphMapSchema:
             "edge_feature_index": {"type": "int32", "shape": [-1]}
         }

+    @property
     def get_schema(self):
         """
         Get schema

@@ -52,6 +54,7 @@ class GraphMapSchema:
         Set node features profile
         """
         if num_features != len(features_data_type) or num_features != len(features_shape):
+            logger.info("Node feature profile is not match.")
             raise ValueError("Node feature profile is not match.")

         self.num_node_features = num_features

@@ -66,6 +69,7 @@ class GraphMapSchema:
         Set edge features profile
         """
         if num_features != len(features_data_type) or num_features != len(features_shape):
+            logger.info("Edge feature profile is not match.")
             raise ValueError("Edge feature profile is not match.")

         self.num_edge_features = num_features

@@ -83,6 +87,10 @@ class GraphMapSchema:
         Returns:
             graph data with union schema
         """
+        if node is None:
+            logger.info("node cannot be None.")
+            raise ValueError("node cannot be None.")
+
         node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"],
                       "node_feature_index": []}
         for i in range(self.num_node_features):

@@ -117,6 +125,10 @@ class GraphMapSchema:
         Returns:
             graph data with union schema
         """
+        if edge is None:
+            logger.info("edge cannot be None.")
+            raise ValueError("edge cannot be None.")
+
         edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e',
                       "type": edge["type"], "edge_feature_index": []}

@@ -164,7 +164,7 @@ if __name__ == "__main__":
     num_features, feature_data_types, feature_shapes = mr_api.edge_profile
     graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes)
-    graph_schema = graph_map_schema.get_schema()
+    graph_schema = graph_map_schema.get_schema

     # init writer
     writer = init_writer(graph_schema)
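
Because get_schema is now decorated with @property, the call site above drops the parentheses. A simplified stand-in, not the real class, showing why the two edits go together:

    class GraphMapSchema:
        def __init__(self):
            self._schema = {"first_id": {"type": "int64"}}

        @property
        def get_schema(self):
            """Get schema."""
            return self._schema

    schema = GraphMapSchema().get_schema    # attribute-style access, no ()
    # GraphMapSchema().get_schema() would now fail: the returned dict is not callable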

@@ -983,7 +983,9 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz
         continue;
       }
     }
-    memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, count * type_size);
+    int return_code = memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size,
+                               count * type_size);
+    CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed in SliceNumeric");
     out_index += count;
     if (i < indices.size() - 1) {
       src_start = HandleNeg(indices[i + 1], dim_length);  // next index
| @@ -101,6 +101,9 @@ class BucketBatchByLengthOp : public PipelineOp { | |||||
| std::vector<int32_t> bucket_batch_sizes, py::function element_length_function, PadInfo pad_info, | std::vector<int32_t> bucket_batch_sizes, py::function element_length_function, PadInfo pad_info, | ||||
| bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size); | bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size); | ||||
| // Destructor | |||||
| ~BucketBatchByLengthOp() = default; | |||||
| // Might need to batch remaining buckets after receiving eoe, so override this method. | // Might need to batch remaining buckets after receiving eoe, so override this method. | ||||
| // @param int32_t workerId | // @param int32_t workerId | ||||
| // @return Status - The error code returned | // @return Status - The error code returned | ||||

@@ -36,6 +36,7 @@ GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers)
     : mr_path_(mr_filepath),
       num_workers_(num_workers),
       row_id_(0),
+      shard_reader_(nullptr),
       keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}

 Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map,

@@ -37,7 +37,7 @@ namespace dataset {
 // Driver method for TreePass
 Status TreePass::Run(ExecutionTree *tree, bool *modified) {
-  if (!tree || !modified) {
+  if (tree == nullptr || modified == nullptr) {
     return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass");
   }
   return this->RunOnTree(tree, modified);

@@ -45,7 +45,7 @@ Status TreePass::Run(ExecutionTree *tree, bool *modified) {
 // Driver method for NodePass
 Status NodePass::Run(ExecutionTree *tree, bool *modified) {
-  if (!tree || !modified) {
+  if (tree == nullptr || modified == nullptr) {
     return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass");
   }
   std::shared_ptr<DatasetOp> root = tree->root();

@@ -44,7 +44,7 @@ class ConnectorSize : public Sampling {
  public:
  explicit ConnectorSize(ExecutionTree *tree) : tree_(tree) {}

-  ~ConnectorSize() = default;
+  ~ConnectorSize() override = default;

  // Driver function for connector size sampling.
  // This function samples the connector size of every nodes within the ExecutionTree

@@ -26,6 +26,7 @@ Monitor::Monitor(ExecutionTree *tree) : tree_(tree) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   sampling_interval_ = cfg->monitor_sampling_interval();
   max_samples_ = 0;
+  cur_row_ = 0;
 }

 Status Monitor::operator()() {

@@ -34,6 +34,8 @@ class Slice {
   Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {}
   explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {}

+  ~Slice() = default;
+
   std::vector<dsize_t> Indices(dsize_t length) {
     std::vector<dsize_t> indices;
     dsize_t index = std::min(Tensor::HandleNeg(start_, length), length);

@@ -29,8 +29,8 @@ Status RandomHorizontalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow
   BOUNDING_BOX_CHECK(input);
   if (distribution_(rnd_)) {
     // To test bounding boxes algorithm, create random bboxes from image dims
-    size_t num_of_boxes = input[1]->shape()[0];     // set to give number of bboxes
-    float img_center = (input[0]->shape()[1] / 2);  // get the center of the image
+    size_t num_of_boxes = input[1]->shape()[0];      // set to give number of bboxes
+    float img_center = (input[0]->shape()[1] / 2.);  // get the center of the image

     for (int i = 0; i < num_of_boxes; i++) {
       uint32_t b_w = 0;  // bounding box width

@@ -49,6 +49,7 @@ BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, Normal
       preserve_unused_token_(preserve_unused_token),
       case_fold_(std::make_unique<CaseFoldOp>()),
       nfd_normalize_(std::make_unique<NormalizeUTF8Op>(NormalizeForm::kNfd)),
+      normalization_form_(normalization_form),
       common_normalize_(std::make_unique<NormalizeUTF8Op>(normalization_form)),
       replace_accent_chars_(std::make_unique<RegexReplaceOp>("\\p{Mn}", "")),
       replace_control_chars_(std::make_unique<RegexReplaceOp>("\\p{Cc}|\\p{Cf}", " ")) {

@@ -35,9 +35,9 @@ class BasicTokenizerOp : public TensorOp {
   static const bool kDefKeepWhitespace;
   static const NormalizeForm kDefNormalizationForm;
   static const bool kDefPreserveUnusedToken;
-  BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace,
-                   NormalizeForm normalization_form = kDefNormalizationForm,
-                   bool preserve_unused_token = kDefPreserveUnusedToken);
+  explicit BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace,
+                            NormalizeForm normalization_form = kDefNormalizationForm,
+                            bool preserve_unused_token = kDefPreserveUnusedToken);

   ~BasicTokenizerOp() override = default;

@@ -28,14 +28,14 @@ namespace mindspore {
 namespace dataset {
 class BertTokenizerOp : public TensorOp {
  public:
-  BertTokenizerOp(const std::shared_ptr<Vocab> &vocab,
-                  const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator,
-                  const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken,
-                  const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken,
-                  bool lower_case = BasicTokenizerOp::kDefLowerCase,
-                  bool keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace,
-                  NormalizeForm normalization_form = BasicTokenizerOp::kDefNormalizationForm,
-                  bool preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken)
+  explicit BertTokenizerOp(const std::shared_ptr<Vocab> &vocab,
+                           const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator,
+                           const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken,
+                           const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken,
+                           bool lower_case = BasicTokenizerOp::kDefLowerCase,
+                           bool keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace,
+                           NormalizeForm normalization_form = BasicTokenizerOp::kDefNormalizationForm,
+                           bool preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken)
       : wordpiece_tokenizer_(vocab, suffix_indicator, max_bytes_per_token, unknown_token),
         basic_tokenizer_(lower_case, keep_whitespace, normalization_form, preserve_unused_token) {}

@@ -48,7 +48,7 @@ class AutoIndexObj : public BPlusTree<int64_t, T, A> {
   // @return
   Status insert(const value_type &val, key_type *key = nullptr) {
     key_type my_inx = inx_.fetch_add(1);
-    if (key) {
+    if (key != nullptr) {
       *key = my_inx;
     }
     return my_tree::DoInsert(my_inx, val);

@@ -323,7 +323,7 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob)
 }

 vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) {
-  uint64_t i_size = kUnsignedOne << int_type;
+  uint64_t i_size = kUnsignedOne << static_cast<uint8_t>(int_type);
   // Get number of elements
   uint64_t src_n_int = src_bytes.size() / i_size;
   // Calculate bitmap size (bytes)

@@ -344,7 +344,7 @@ vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const
     // Initialize destination data type
     IntegerType dst_int_type = kInt8Type;
     // Shift to next int position
-    uint64_t pos = i * (kUnsignedOne << int_type);
+    uint64_t pos = i * (kUnsignedOne << static_cast<uint8_t>(int_type));
     // Narrow down this int
     int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type);

@@ -61,7 +61,7 @@ class Shuffle(str, Enum):
 @check_zip
 def zip(datasets):
     """
-    Zips the datasets in the input tuple of datasets.
+    Zip the datasets in the input tuple of datasets.

     Args:
         datasets (tuple of class Dataset): A tuple of datasets to be zipped together.

@@ -152,7 +152,7 @@ class Dataset:
     def get_args(self):
         """
-        Returns attributes (member variables) related to the current class.
+        Return attributes (member variables) related to the current class.

         Must include all arguments passed to the __init__() of the current class, excluding 'input_dataset'.

@@ -239,7 +239,7 @@ class Dataset:
     def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
               input_columns=None, pad_info=None):
         """
-        Combines batch_size number of consecutive rows into batches.
+        Combine batch_size number of consecutive rows into batches.

         For any child node, a batch is treated as a single row.
         For any column, all the elements within that column must have the same shape.

@@ -340,7 +340,7 @@ class Dataset:
     def flat_map(self, func):
         """
-        Maps `func` to each row in dataset and flatten the result.
+        Map `func` to each row in dataset and flatten the result.

         The specified `func` is a function that must take one 'Ndarray' as input
         and return a 'Dataset'.

@@ -370,6 +370,7 @@ class Dataset:
         """
         dataset = None
         if not hasattr(func, '__call__'):
+            logger.error("func must be a function.")
             raise TypeError("func must be a function.")

         for row_data in self:

@@ -379,6 +380,7 @@ class Dataset:
             dataset += func(row_data)

         if not isinstance(dataset, Dataset):
+            logger.error("flat_map must return a Dataset object.")
             raise TypeError("flat_map must return a Dataset object.")

         return dataset
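
For reference, a hedged usage sketch consistent with the checks above: func must be callable and must return a Dataset, otherwise the newly logged TypeErrors are raised. All names below are illustrative, not from this PR:

    import numpy as np
    import mindspore.dataset as ds

    def make_gen(start):
        def _gen():
            for i in range(start, start + 3):
                yield (np.array(i),)
        return _gen

    def flat_map_func(row):
        # row is one row of the outer dataset; build and return a new Dataset from it
        return ds.GeneratorDataset(make_gen(int(row[0])), ["col1"])

    outer = ds.GeneratorDataset(make_gen(0), ["col1"])
    flat = outer.flat_map(flat_map_func)    # outer.flat_map(None) would raise TypeError("func must be a function.")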

@@ -386,7 +388,7 @@ class Dataset:
     def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
             num_parallel_workers=None, python_multiprocessing=False):
         """
-        Applies each operation in operations to this dataset.
+        Apply each operation in operations to this dataset.

         The order of operations is determined by the position of each operation in operations.
         operations[0] will be applied first, then operations[1], then operations[2], etc.

@@ -570,7 +572,7 @@ class Dataset:
     @check_repeat
     def repeat(self, count=None):
         """
-        Repeats this dataset count times. Repeat indefinitely if the count is None or -1.
+        Repeat this dataset count times. Repeat indefinitely if the count is None or -1.

         Note:
             The order of using repeat and batch reflects the number of batches. Recommend that

@@ -662,13 +664,16 @@ class Dataset:
         dataset_size = self.get_dataset_size()

         if dataset_size is None or dataset_size <= 0:
-            raise RuntimeError("dataset size unknown, unable to split.")
+            raise RuntimeError("dataset_size is unknown, unable to split.")
+
+        if not isinstance(sizes, list):
+            raise RuntimeError("sizes should be a list.")

         all_int = all(isinstance(item, int) for item in sizes)
         if all_int:
             sizes_sum = sum(sizes)
             if sizes_sum != dataset_size:
-                raise RuntimeError("sum of split sizes {} is not equal to dataset size {}."
+                raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}."
                                    .format(sizes_sum, dataset_size))
             return sizes

@@ -676,7 +681,7 @@ class Dataset:
         for item in sizes:
             absolute_size = int(round(item * dataset_size))
             if absolute_size == 0:
-                raise RuntimeError("split percentage {} is too small.".format(item))
+                raise RuntimeError("Split percentage {} is too small.".format(item))
             absolute_sizes.append(absolute_size)

         absolute_sizes_sum = sum(absolute_sizes)

@@ -694,7 +699,7 @@ class Dataset:
                 break

         if sum(absolute_sizes) != dataset_size:
-            raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}."
+            raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}."
                                .format(absolute_sizes_sum, dataset_size))

         return absolute_sizes

@@ -702,7 +707,7 @@ class Dataset:
     @check_split
     def split(self, sizes, randomize=True):
         """
-        Splits the dataset into smaller, non-overlapping datasets.
+        Split the dataset into smaller, non-overlapping datasets.

         This is a general purpose split function which can be called from any operator in the pipeline.
         There is another, optimized split function, which will be called automatically if ds.split is

@@ -759,10 +764,10 @@ class Dataset:
             >>> train, test = data.split([0.9, 0.1])
         """
         if self.is_shuffled():
-            logger.warning("dataset is shuffled before split.")
+            logger.warning("Dataset is shuffled before split.")

         if self.is_sharded():
-            raise RuntimeError("dataset should not be sharded before split.")
+            raise RuntimeError("Dataset should not be sharded before split.")

         absolute_sizes = self._get_absolute_split_sizes(sizes)

         splits = []
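
A short usage sketch matching the reworded messages above (the file name is assumed): split accepts absolute sizes or fractions, and refuses to run on a sharded dataset.

    import mindspore.dataset as ds

    data = ds.TextFileDataset("corpus.txt", shuffle=False)
    train, test = data.split([0.9, 0.1], randomize=True)

    sharded = ds.TextFileDataset("corpus.txt", num_shards=2, shard_id=0)
    # sharded.split([0.9, 0.1]) -> RuntimeError: Dataset should not be sharded before split.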

@@ -788,7 +793,7 @@ class Dataset:
     @check_zip_dataset
     def zip(self, datasets):
         """
-        Zips the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.
+        Zip the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.

         Args:
             datasets (tuple or class Dataset): A tuple of datasets or a single class Dataset

@@ -845,7 +850,7 @@ class Dataset:
     @check_rename
     def rename(self, input_columns, output_columns):
         """
-        Renames the columns in input datasets.
+        Rename the columns in input datasets.

         Args:
             input_columns (list[str]): list of names of the input columns.

@@ -871,7 +876,7 @@ class Dataset:
     @check_project
     def project(self, columns):
         """
-        Projects certain columns in input datasets.
+        Project certain columns in input datasets.

         The specified columns will be selected from the dataset and passed down
         the pipeline in the order specified. The other columns are discarded.

@@ -936,7 +941,7 @@ class Dataset:
     def device_que(self, prefetch_size=None):
         """
-        Returns a transferredDataset that transfer data through device.
+        Return a transferredDataset that transfer data through device.

         Args:
             prefetch_size (int, optional): prefetch number of records ahead of the

@@ -953,7 +958,7 @@ class Dataset:
     def to_device(self, num_batch=None):
         """
-        Transfers data through CPU, GPU or Ascend devices.
+        Transfer data through CPU, GPU or Ascend devices.

         Args:
             num_batch (int, optional): limit the number of batch to be sent to device (default=None).

@@ -988,7 +993,7 @@ class Dataset:
             raise TypeError("Please set device_type in context")

         if device_type not in ('Ascend', 'GPU', 'CPU'):
-            raise ValueError("only support CPU, Ascend, GPU")
+            raise ValueError("Only support CPU, Ascend, GPU")

         if num_batch is None or num_batch == 0:
             raise ValueError("num_batch is None or 0.")

@@ -1089,7 +1094,7 @@ class Dataset:
     def _get_pipeline_info(self):
         """
-        Gets pipeline information.
+        Get pipeline information.
         """
         device_iter = TupleIterator(self)
         self._output_shapes = device_iter.get_output_shapes()

@@ -1344,7 +1349,7 @@ class MappableDataset(SourceDataset):
     @check_split
     def split(self, sizes, randomize=True):
         """
-        Splits the dataset into smaller, non-overlapping datasets.
+        Split the dataset into smaller, non-overlapping datasets.

         There is the optimized split function, which will be called automatically when the dataset
         that calls this function is a MappableDataset.

@@ -1411,10 +1416,10 @@ class MappableDataset(SourceDataset):
             >>> train.use_sampler(train_sampler)
         """
         if self.is_shuffled():
-            logger.warning("dataset is shuffled before split.")
+            logger.warning("Dataset is shuffled before split.")

         if self.is_sharded():
-            raise RuntimeError("dataset should not be sharded before split.")
+            raise RuntimeError("Dataset should not be sharded before split.")

         absolute_sizes = self._get_absolute_split_sizes(sizes)

         splits = []

@@ -1633,7 +1638,7 @@ class BlockReleasePair:
     def __init__(self, init_release_rows, callback=None):
         if isinstance(init_release_rows, int) and init_release_rows <= 0:
-            raise ValueError("release_rows need to be greater than 0.")
+            raise ValueError("release_rows need to be greater than 0.")
         self.row_count = -init_release_rows
         self.cv = threading.Condition()
         self.callback = callback

@@ -2699,10 +2704,10 @@ class MindDataset(MappableDataset):
         self.shard_id = shard_id

         if block_reader is True and num_shards is not None:
-            raise ValueError("block reader not allowed true when use partitions")
+            raise ValueError("block_reader not allowed true when use partitions")

         if block_reader is True and shuffle is True:
-            raise ValueError("block reader not allowed true when use shuffle")
+            raise ValueError("block_reader not allowed true when use shuffle")

         if block_reader is True:
             logger.warning("WARN: global shuffle is not used.")

@@ -2711,14 +2716,14 @@ class MindDataset(MappableDataset):
         if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler,
                                 samplers.DistributedSampler, samplers.RandomSampler,
                                 samplers.SequentialSampler)) is False:
-            raise ValueError("the sampler is not supported yet.")
+            raise ValueError("The sampler is not supported yet.")

         self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples

         # sampler exclusive
         if block_reader is True and sampler is not None:
-            raise ValueError("block reader not allowed true when use sampler")
+            raise ValueError("block_reader not allowed true when use sampler")

         if num_padded is None:
             num_padded = 0
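
For context on the block_reader checks above, a hedged sketch (the file name is assumed): block_reader=True cannot be combined with sharding, shuffle, or an explicit sampler.

    import mindspore.dataset as ds

    data = ds.MindDataset("imagenet.mindrecord", num_shards=2, shard_id=0)   # fine: sharding with the default reader
    # ds.MindDataset("imagenet.mindrecord", block_reader=True, num_shards=2, shard_id=0)
    #   -> ValueError: block_reader not allowed true when use partitions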

@@ -2770,7 +2775,7 @@ class MindDataset(MappableDataset):
             if value >= 0:
                 self._dataset_size = value
             else:
-                raise ValueError('set dataset_size with negative value {}'.format(value))
+                raise ValueError('Set dataset_size with negative value {}'.format(value))

     def is_shuffled(self):
         if self.shuffle_option is None:

@@ -2872,7 +2877,7 @@ def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
 def _fetch_py_sampler_indices(sampler, num_samples):
     """
-    Indices fetcher for python sampler.
+    Indice fetcher for python sampler.
     """
     if num_samples is not None:
         sampler_iter = iter(sampler)

@@ -3163,7 +3168,7 @@ class GeneratorDataset(MappableDataset):
             if value >= 0:
                 self._dataset_size = value
             else:
-                raise ValueError('set dataset_size with negative value {}'.format(value))
+                raise ValueError('Set dataset_size with negative value {}'.format(value))

     def __deepcopy__(self, memodict):
         if id(self) in memodict:

@@ -3313,7 +3318,7 @@ class TFRecordDataset(SourceDataset):
             if value >= 0:
                 self._dataset_size = value
             else:
-                raise ValueError('set dataset_size with negative value {}'.format(value))
+                raise ValueError('Set dataset_size with negative value {}'.format(value))

     def is_shuffled(self):
         return self.shuffle_files

@@ -4382,7 +4387,9 @@ class CelebADataset(MappableDataset):
         try:
             with open(attr_file, 'r') as f:
                 num_rows = int(f.readline())
-        except Exception:
+        except FileNotFoundError:
+            raise RuntimeError("attr_file not found.")
+        except BaseException:
             raise RuntimeError("Get dataset size failed from attribution file.")
         rows_per_shard = get_num_rows(num_rows, self.num_shards)
         if self.num_samples is not None:

@@ -319,7 +319,7 @@ class PKSampler(BuiltinSampler):
             raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))

         if num_class is not None:
-            raise NotImplementedError
+            raise NotImplementedError("Not support specify num_class")

         if not isinstance(shuffle, bool):
             raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))

@@ -551,8 +551,8 @@ class WeightedRandomSampler(BuiltinSampler):
     Args:
         weights (list[float]): A sequence of weights, not necessarily summing up to 1.
-        num_samples (int): Number of elements to sample (default=None, all elements).
-        replacement (bool, optional): If True, put the sample ID back for the next draw (default=True).
+        num_samples (int, optional): Number of elements to sample (default=None, all elements).
+        replacement (bool): If True, put the sample ID back for the next draw (default=True).

     Examples:
         >>> import mindspore.dataset as ds

@@ -50,7 +50,7 @@ def check_filename(path):
        Exception: when error
    """
    if not isinstance(path, str):
-        raise ValueError("path: {} is not string".format(path))
+        raise TypeError("path: {} is not string".format(path))
    filename = os.path.basename(path)

    # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',

@@ -143,7 +143,7 @@ def check_sampler_shuffle_shard_options(param_dict):
    num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')

    if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)):
-        raise ValueError("sampler is not a valid Sampler type.")
+        raise TypeError("sampler is not a valid Sampler type.")

    if sampler is not None:
        if shuffle is not None:

@@ -328,13 +328,13 @@ def check_vocdataset(method):
        if task is None:
            raise ValueError("task is not provided.")
        if not isinstance(task, str):
-            raise ValueError("task is not str type.")
+            raise TypeError("task is not str type.")

        # check mode; required argument
        mode = param_dict.get('mode')
        if mode is None:
            raise ValueError("mode is not provided.")
        if not isinstance(mode, str):
-            raise ValueError("mode is not str type.")
+            raise TypeError("mode is not str type.")

        imagesets_file = ""
        if task == "Segmentation":

@@ -388,7 +388,7 @@ def check_cocodataset(method):
        if task is None:
            raise ValueError("task is not provided.")
        if not isinstance(task, str):
-            raise ValueError("task is not str type.")
+            raise TypeError("task is not str type.")

        if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
            raise ValueError("Invalid task type")

@@ -556,7 +556,7 @@ def check_generatordataset(method):
    def check_batch_size(batch_size):
        if not (isinstance(batch_size, int) or (callable(batch_size))):
-            raise ValueError("batch_size should either be an int or a callable.")
+            raise TypeError("batch_size should either be an int or a callable.")
        if callable(batch_size):
            sig = ins.signature(batch_size)
            if len(sig.parameters) != 1:
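
The batch_size validator above accepts either an int or a callable taking a single argument (a BatchInfo object). A hedged sketch of the callable form, with an illustrative generator and sizes:

    import numpy as np
    import mindspore.dataset as ds

    def gen():
        for i in range(64):
            yield (np.array(i),)

    def per_epoch_batch(batch_info):
        # exactly one parameter: a BatchInfo object describing the current batch/epoch
        return 16 + 16 * batch_info.get_epoch_num()

    data = ds.GeneratorDataset(gen, ["col1"]).batch(batch_size=per_epoch_batch)
    # ds.GeneratorDataset(gen, ["col1"]).batch("32") now raises TypeError instead of ValueError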

@@ -706,6 +706,7 @@ def check_batch(method):
 def check_sync_wait(method):
     """check the input arguments of sync_wait."""

+    @wraps(method)
     def new_method(*args, **kwargs):
         param_dict = make_param_dict(method, args, kwargs)

@@ -773,7 +774,7 @@ def check_filter(method):
         param_dict = make_param_dict(method, args, kwargs)
         predicate = param_dict.get("predicate")
         if not callable(predicate):
-            raise ValueError("Predicate should be a python function or a callable python object.")
+            raise TypeError("Predicate should be a python function or a callable python object.")

         nreq_param_int = ['num_parallel_workers']
         check_param_type(nreq_param_int, param_dict, int)

@@ -865,7 +866,7 @@ def check_zip_dataset(method):
             raise ValueError("datasets is not provided.")

         if not isinstance(ds, (tuple, datasets.Dataset)):
-            raise ValueError("datasets is not tuple or of type Dataset.")
+            raise TypeError("datasets is not tuple or of type Dataset.")

         return method(*args, **kwargs)

@@ -885,7 +886,7 @@ def check_concat(method):
             raise ValueError("datasets is not provided.")

         if not isinstance(ds, (list, datasets.Dataset)):
-            raise ValueError("datasets is not list or of type Dataset.")
+            raise TypeError("datasets is not list or of type Dataset.")

         return method(*args, **kwargs)

@@ -964,7 +965,7 @@ def check_add_column(method):
         de_type = param_dict.get("de_type")
         if de_type is not None:
             if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
-                raise ValueError("Unknown column type.")
+                raise TypeError("Unknown column type.")
         else:
             raise TypeError("Expected non-empty string.")

@@ -10,6 +10,6 @@ wheel >= 0.32.0
 decorator >= 4.4.0
 setuptools >= 40.8.0
 matplotlib >= 3.1.3 # for ut test
-opencv-python >= 4.2.0.32 # for ut test
+opencv-python >= 4.1.2.30 # for ut test
 sklearn >= 0.0 # for st test
 pandas >= 1.0.2 # for ut test

@@ -42,15 +42,15 @@ def split_with_invalid_inputs(d):
     with pytest.raises(RuntimeError) as info:
         _, _ = d.split([3, 1])
-    assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)
+    assert "Sum of split sizes 4 is not equal to dataset size 5" in str(info.value)

     with pytest.raises(RuntimeError) as info:
         _, _ = d.split([5, 1])
-    assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)
+    assert "Sum of split sizes 6 is not equal to dataset size 5" in str(info.value)

     with pytest.raises(RuntimeError) as info:
         _, _ = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
-    assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)
+    assert "Sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)

     with pytest.raises(ValueError) as info:
         _, _ = d.split([-0.5, 0.5])

@@ -80,7 +80,7 @@ def test_unmappable_invalid_input():
     d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
     with pytest.raises(RuntimeError) as info:
         _, _ = d.split([4, 1])
-    assert "dataset should not be sharded before split" in str(info.value)
+    assert "Dataset should not be sharded before split" in str(info.value)


 def test_unmappable_split():

@@ -274,7 +274,7 @@ def test_mappable_invalid_input():
     d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
     with pytest.raises(RuntimeError) as info:
         _, _ = d.split([4, 1])
-    assert "dataset should not be sharded before split" in str(info.value)
+    assert "Dataset should not be sharded before split" in str(info.value)


 def test_mappable_split_general():