| @@ -14,10 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/include/datasets.h" | |||
| #include <algorithm> | |||
| #include <fstream> | |||
| #include <unordered_set> | |||
| #include "minddata/dataset/include/datasets.h" | |||
| #include "minddata/dataset/include/samplers.h" | |||
| #include "minddata/dataset/include/transforms.h" | |||
| // Source dataset headers (in alphabetical order) | |||
| @@ -696,7 +696,7 @@ Status ValidateDatasetDirParam(const std::string &dataset_name, std::string data | |||
| return Status::OK(); | |||
| } | |||
| // Helper function to validate dataset dataset files parameter | |||
| // Helper function to validate dataset files parameter | |||
| Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) { | |||
| if (dataset_files.empty()) { | |||
| std::string err_msg = dataset_name + ": dataset_files is not specified."; | |||
| @@ -743,7 +743,6 @@ Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_s | |||
| // Helper function to validate dataset sampler parameter | |||
| Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) { | |||
| if (sampler == nullptr) { | |||
| MS_LOG(ERROR) << dataset_name << ": Sampler is not constructed correctly, sampler: nullptr"; | |||
| std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| @@ -751,12 +750,13 @@ Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared | |||
| return Status::OK(); | |||
| } | |||
| Status ValidateStringValue(const std::string &str, const std::unordered_set<std::string> &valid_strings) { | |||
| Status ValidateStringValue(const std::string &dataset_name, const std::string &str, | |||
| const std::unordered_set<std::string> &valid_strings) { | |||
| if (valid_strings.find(str) == valid_strings.end()) { | |||
| std::string mode; | |||
| mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode, | |||
| [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); }); | |||
| std::string err_msg = str + " does not match any mode in [" + mode + " ]"; | |||
| std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| @@ -842,7 +842,7 @@ Status CelebANode::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebANode", sampler_)); | |||
| RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"all", "train", "valid", "test"})); | |||
| RETURN_IF_NOT_OK(ValidateStringValue("CelebANode", usage_, {"all", "train", "valid", "test"})); | |||
| return Status::OK(); | |||
| } | |||
| @@ -873,7 +873,7 @@ Status Cifar10Node::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Node", sampler_)); | |||
| RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); | |||
| RETURN_IF_NOT_OK(ValidateStringValue("Cifar10Node", usage_, {"train", "test", "all"})); | |||
| return Status::OK(); | |||
| } | |||
| @@ -906,7 +906,7 @@ Status Cifar100Node::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Node", sampler_)); | |||
| RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); | |||
| RETURN_IF_NOT_OK(ValidateStringValue("Cifar100Node", usage_, {"train", "test", "all"})); | |||
| return Status::OK(); | |||
| } | |||
| @@ -945,20 +945,9 @@ CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, | |||
| Status CLUENode::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_)); | |||
| std::vector<std::string> task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}; | |||
| std::vector<std::string> usage_list = {"train", "test", "eval"}; | |||
| RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", task_, {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"})); | |||
| if (find(task_list.begin(), task_list.end(), task_) == task_list.end()) { | |||
| std::string err_msg = "task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) { | |||
| std::string err_msg = "usage should be train, test or eval."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", usage_, {"train", "test", "eval"})); | |||
| if (num_samples_ < 0) { | |||
| std::string err_msg = "CLUENode: Invalid number of samples: " + std::to_string(num_samples_); | |||
| @@ -1133,18 +1122,12 @@ Status CocoNode::ValidateParams() { | |||
| Path annotation_file(annotation_file_); | |||
| if (!annotation_file.Exists()) { | |||
| std::string err_msg = "annotation_file is invalid or not exist"; | |||
| std::string err_msg = "CocoNode: annotation_file is invalid or does not exist."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| std::set<std::string> task_list = {"Detection", "Stuff", "Panoptic", "Keypoint"}; | |||
| auto task_iter = task_list.find(task_); | |||
| if (task_iter == task_list.end()) { | |||
| std::string err_msg = "Invalid task type"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| RETURN_IF_NOT_OK(ValidateStringValue("CocoNode", task_, {"Detection", "Stuff", "Panoptic", "Keypoint"})); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1348,7 +1331,7 @@ Status ManifestNode::ValidateParams() { | |||
| for (char c : dataset_file_) { | |||
| auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c); | |||
| if (p != forbidden_symbols.end()) { | |||
| std::string err_msg = "filename should not contains :*?\"<>|`&;\'"; | |||
| std::string err_msg = "ManifestNode: filename should not contain :*?\"<>|`&;\'"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| @@ -1356,19 +1339,14 @@ Status ManifestNode::ValidateParams() { | |||
| Path manifest_file(dataset_file_); | |||
| if (!manifest_file.Exists()) { | |||
| std::string err_msg = "dataset file: [" + dataset_file_ + "] is invalid or not exist"; | |||
| std::string err_msg = "ManifestNode: dataset file: [" + dataset_file_ + "] is invalid or not exist"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("ManifestNode", sampler_)); | |||
| std::vector<std::string> usage_list = {"train", "eval", "inference"}; | |||
| if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) { | |||
| std::string err_msg = "usage should be train, eval or inference."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| RETURN_IF_NOT_OK(ValidateStringValue("ManifestNode", usage_, {"train", "eval", "inference"})); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1536,7 +1514,7 @@ Status MnistNode::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistNode", sampler_)); | |||
| RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); | |||
| RETURN_IF_NOT_OK(ValidateStringValue("MnistNode", usage_, {"train", "test", "all"})); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1753,35 +1731,32 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const | |||
| Status VOCNode::ValidateParams() { | |||
| Path dir(dataset_dir_); | |||
| if (!dir.IsDirectory()) { | |||
| std::string err_msg = "Invalid dataset path or no dataset path is specified."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("VOCNode", dataset_dir_)); | |||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("VOCNode", sampler_)); | |||
| if (task_ == "Segmentation") { | |||
| if (!class_index_.empty()) { | |||
| std::string err_msg = "class_indexing is invalid in Segmentation task."; | |||
| std::string err_msg = "VOCNode: class_indexing is invalid in Segmentation task."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| Path imagesets_file = dir / "ImageSets" / "Segmentation" / usage_ + ".txt"; | |||
| if (!imagesets_file.Exists()) { | |||
| std::string err_msg = "Invalid usage: " + usage_ + ", file does not exist"; | |||
| MS_LOG(ERROR) << "Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; | |||
| std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist"; | |||
| MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| } else if (task_ == "Detection") { | |||
| Path imagesets_file = dir / "ImageSets" / "Main" / usage_ + ".txt"; | |||
| if (!imagesets_file.Exists()) { | |||
| std::string err_msg = "Invalid usage: " + usage_ + ", file does not exist"; | |||
| MS_LOG(ERROR) << "Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; | |||
| std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist"; | |||
| MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| } else { | |||
| std::string err_msg = "Invalid task: " + task_; | |||
| std::string err_msg = "VOCNode: Invalid task: " + task_; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| @@ -1859,15 +1834,17 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() { | |||
| Status BatchNode::ValidateParams() { | |||
| if (batch_size_ <= 0) { | |||
| std::string err_msg = "Batch: batch_size should be positive integer, but got: " + std::to_string(batch_size_); | |||
| std::string err_msg = "BatchNode: batch_size should be positive integer, but got: " + std::to_string(batch_size_); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (!cols_to_map_.empty()) { | |||
| std::string err_msg = "cols_to_map functionality is not implemented in C++; this should be left empty."; | |||
| std::string err_msg = "BatchNode: cols_to_map functionality is not implemented in C++; this should be left empty."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -1906,28 +1883,29 @@ std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() { | |||
| Status BucketBatchByLengthNode::ValidateParams() { | |||
| if (element_length_function_ == nullptr && column_names_.size() != 1) { | |||
| std::string err_msg = | |||
| "BucketBatchByLength: element_length_function not specified, but not one column name: " + column_names_.size(); | |||
| std::string err_msg = "BucketBatchByLengthNode: element_length_function not specified, but not one column name: " + | |||
| column_names_.size(); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| // Check bucket_boundaries: must be positive and strictly increasing | |||
| if (bucket_boundaries_.empty()) { | |||
| std::string err_msg = "BucketBatchByLength: bucket_boundaries cannot be empty."; | |||
| std::string err_msg = "BucketBatchByLengthNode: bucket_boundaries cannot be empty."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| for (int i = 0; i < bucket_boundaries_.size(); i++) { | |||
| if (bucket_boundaries_[i] <= 0) { | |||
| std::string err_msg = "BucketBatchByLength: Invalid non-positive bucket_boundaries, index: "; | |||
| std::string err_msg = "BucketBatchByLengthNode: Invalid non-positive bucket_boundaries, index: "; | |||
| MS_LOG(ERROR) | |||
| << "BucketBatchByLength: bucket_boundaries must only contain positive numbers. However, the element at index: " | |||
| << i << " was: " << bucket_boundaries_[i]; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (i > 0 && bucket_boundaries_[i - 1] >= bucket_boundaries_[i]) { | |||
| std::string err_msg = "BucketBatchByLength: Invalid bucket_boundaries not be strictly increasing."; | |||
| std::string err_msg = "BucketBatchByLengthNode: Invalid bucket_boundaries not be strictly increasing."; | |||
| MS_LOG(ERROR) | |||
| << "BucketBatchByLength: bucket_boundaries must be strictly increasing. However, the elements at index: " | |||
| << i - 1 << " and " << i << " were: " << bucket_boundaries_[i - 1] << " and " << bucket_boundaries_[i] | |||
| @@ -1938,20 +1916,24 @@ Status BucketBatchByLengthNode::ValidateParams() { | |||
| // Check bucket_batch_sizes: must be positive | |||
| if (bucket_batch_sizes_.empty()) { | |||
| std::string err_msg = "BucketBatchByLength: bucket_batch_sizes must be non-empty"; | |||
| std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must be non-empty"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (bucket_batch_sizes_.size() != bucket_boundaries_.size() + 1) { | |||
| std::string err_msg = "BucketBatchByLength: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1"; | |||
| std::string err_msg = | |||
| "BucketBatchByLengthNode: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (std::any_of(bucket_batch_sizes_.begin(), bucket_batch_sizes_.end(), [](int i) { return i <= 0; })) { | |||
| std::string err_msg = "BucketBatchByLength: bucket_batch_sizes must only contain positive numbers."; | |||
| std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must only contain positive numbers."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -1981,26 +1963,26 @@ std::vector<std::shared_ptr<DatasetOp>> BuildVocabNode::Build() { | |||
| Status BuildVocabNode::ValidateParams() { | |||
| if (vocab_ == nullptr) { | |||
| std::string err_msg = "BuildVocab: vocab is null."; | |||
| std::string err_msg = "BuildVocabNode: vocab is null."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (top_k_ <= 0) { | |||
| std::string err_msg = "BuildVocab: top_k should be positive, but got: " + top_k_; | |||
| std::string err_msg = "BuildVocabNode: top_k should be positive, but got: " + std::to_string(top_k_); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (freq_range_.first < 0 || freq_range_.second > kDeMaxFreq || freq_range_.first > freq_range_.second) { | |||
| std::string err_msg = "BuildVocab: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)"; | |||
| MS_LOG(ERROR) << "BuildVocab: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), " | |||
| std::string err_msg = "BuildVocabNode: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)"; | |||
| MS_LOG(ERROR) << "BuildVocabNode: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), " | |||
| << "but got [" << freq_range_.first << ", " << freq_range_.second << "]"; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (!columns_.empty()) { | |||
| RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocab", "columns", columns_)); | |||
| RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocabNode", "columns", columns_)); | |||
| } | |||
| return Status::OK(); | |||
| @@ -2014,15 +1996,17 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : | |||
| Status ConcatNode::ValidateParams() { | |||
| if (datasets_.empty()) { | |||
| std::string err_msg = "Concat: concatenated datasets are not specified."; | |||
| std::string err_msg = "ConcatNode: concatenated datasets are not specified."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { | |||
| std::string err_msg = "Concat: concatenated datasets should not be null."; | |||
| std::string err_msg = "ConcatNode: concatenated datasets should not be null."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -2070,7 +2054,7 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() { | |||
| Status MapNode::ValidateParams() { | |||
| if (operations_.empty()) { | |||
| std::string err_msg = "Map: No operation is specified."; | |||
| std::string err_msg = "MapNode: No operation is specified."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| @@ -2158,8 +2142,8 @@ std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() { | |||
| Status RepeatNode::ValidateParams() { | |||
| if (repeat_count_ <= 0 && repeat_count_ != -1) { | |||
| std::string err_msg = | |||
| "Repeat: repeat_count should be either -1 or positive integer, repeat_count_: " + std::to_string(repeat_count_); | |||
| std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " + | |||
| std::to_string(repeat_count_); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| @@ -2211,10 +2195,11 @@ std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() { | |||
| // Function to validate the parameters for SkipNode | |||
| Status SkipNode::ValidateParams() { | |||
| if (skip_count_ <= -1) { | |||
| std::string err_msg = "Skip: skip_count should not be negative, skip_count: " + std::to_string(skip_count_); | |||
| std::string err_msg = "SkipNode: skip_count should not be negative, skip_count: " + std::to_string(skip_count_); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -2236,7 +2221,7 @@ std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() { | |||
| Status TakeNode::ValidateParams() { | |||
| if (take_count_ <= 0 && take_count_ != -1) { | |||
| std::string err_msg = | |||
| "Take: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_); | |||
| "TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| @@ -2252,15 +2237,17 @@ ZipNode::ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datase | |||
| Status ZipNode::ValidateParams() { | |||
| if (datasets_.empty()) { | |||
| std::string err_msg = "Zip: datasets to zip are not specified."; | |||
| std::string err_msg = "ZipNode: datasets to zip are not specified."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { | |||
| std::string err_msg = "ZipNode: zip datasets should not be null."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||