|
|
|
@@ -14,10 +14,10 @@ |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "minddata/dataset/include/datasets.h" |
|
|
|
#include <algorithm> |
|
|
|
#include <fstream> |
|
|
|
#include <unordered_set> |
|
|
|
#include "minddata/dataset/include/datasets.h" |
|
|
|
#include "minddata/dataset/include/samplers.h" |
|
|
|
#include "minddata/dataset/include/transforms.h" |
|
|
|
// Source dataset headers (in alphabetical order) |
|
|
|
@@ -696,7 +696,7 @@ Status ValidateDatasetDirParam(const std::string &dataset_name, std::string data |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
// Helper function to validate dataset dataset files parameter |
|
|
|
// Helper function to validate dataset files parameter |
|
|
|
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) { |
|
|
|
if (dataset_files.empty()) { |
|
|
|
std::string err_msg = dataset_name + ": dataset_files is not specified."; |
|
|
|
@@ -743,7 +743,6 @@ Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_s |
|
|
|
// Helper function to validate dataset sampler parameter |
|
|
|
Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) { |
|
|
|
if (sampler == nullptr) { |
|
|
|
MS_LOG(ERROR) << dataset_name << ": Sampler is not constructed correctly, sampler: nullptr"; |
|
|
|
std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr"; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
@@ -751,12 +750,13 @@ Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
Status ValidateStringValue(const std::string &str, const std::unordered_set<std::string> &valid_strings) { |
|
|
|
Status ValidateStringValue(const std::string &dataset_name, const std::string &str, |
|
|
|
const std::unordered_set<std::string> &valid_strings) { |
|
|
|
if (valid_strings.find(str) == valid_strings.end()) { |
|
|
|
std::string mode; |
|
|
|
mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode, |
|
|
|
[](std::string a, std::string b) { return std::move(a) + " " + std::move(b); }); |
|
|
|
std::string err_msg = str + " does not match any mode in [" + mode + " ]"; |
|
|
|
std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]"; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
@@ -842,7 +842,7 @@ Status CelebANode::ValidateParams() { |
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebANode", sampler_)); |
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"all", "train", "valid", "test"})); |
|
|
|
RETURN_IF_NOT_OK(ValidateStringValue("CelebANode", usage_, {"all", "train", "valid", "test"})); |
|
|
|
|
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
@@ -873,7 +873,7 @@ Status Cifar10Node::ValidateParams() { |
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Node", sampler_)); |
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); |
|
|
|
RETURN_IF_NOT_OK(ValidateStringValue("Cifar10Node", usage_, {"train", "test", "all"})); |
|
|
|
|
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
@@ -906,7 +906,7 @@ Status Cifar100Node::ValidateParams() { |
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Node", sampler_)); |
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); |
|
|
|
RETURN_IF_NOT_OK(ValidateStringValue("Cifar100Node", usage_, {"train", "test", "all"})); |
|
|
|
|
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
@@ -945,20 +945,9 @@ CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, |
|
|
|
Status CLUENode::ValidateParams() { |
|
|
|
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_)); |
|
|
|
|
|
|
|
std::vector<std::string> task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}; |
|
|
|
std::vector<std::string> usage_list = {"train", "test", "eval"}; |
|
|
|
RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", task_, {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"})); |
|
|
|
|
|
|
|
if (find(task_list.begin(), task_list.end(), task_) == task_list.end()) { |
|
|
|
std::string err_msg = "task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL."; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) { |
|
|
|
std::string err_msg = "usage should be train, test or eval."; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", usage_, {"train", "test", "eval"})); |
|
|
|
|
|
|
|
if (num_samples_ < 0) { |
|
|
|
std::string err_msg = "CLUENode: Invalid number of samples: " + std::to_string(num_samples_); |
|
|
|
@@ -1133,18 +1122,12 @@ Status CocoNode::ValidateParams() { |
|
|
|
|
|
|
|
Path annotation_file(annotation_file_); |
|
|
|
if (!annotation_file.Exists()) { |
|
|
|
std::string err_msg = "annotation_file is invalid or not exist"; |
|
|
|
std::string err_msg = "CocoNode: annotation_file is invalid or does not exist."; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
std::set<std::string> task_list = {"Detection", "Stuff", "Panoptic", "Keypoint"}; |
|
|
|
auto task_iter = task_list.find(task_); |
|
|
|
if (task_iter == task_list.end()) { |
|
|
|
std::string err_msg = "Invalid task type"; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
RETURN_IF_NOT_OK(ValidateStringValue("CocoNode", task_, {"Detection", "Stuff", "Panoptic", "Keypoint"})); |
|
|
|
|
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
@@ -1348,7 +1331,7 @@ Status ManifestNode::ValidateParams() { |
|
|
|
for (char c : dataset_file_) { |
|
|
|
auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c); |
|
|
|
if (p != forbidden_symbols.end()) { |
|
|
|
std::string err_msg = "filename should not contains :*?\"<>|`&;\'"; |
|
|
|
std::string err_msg = "ManifestNode: filename should not contain :*?\"<>|`&;\'"; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
@@ -1356,19 +1339,14 @@ Status ManifestNode::ValidateParams() { |
|
|
|
|
|
|
|
Path manifest_file(dataset_file_); |
|
|
|
if (!manifest_file.Exists()) { |
|
|
|
std::string err_msg = "dataset file: [" + dataset_file_ + "] is invalid or not exist"; |
|
|
|
std::string err_msg = "ManifestNode: dataset file: [" + dataset_file_ + "] is invalid or not exist"; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(ValidateDatasetSampler("ManifestNode", sampler_)); |
|
|
|
|
|
|
|
std::vector<std::string> usage_list = {"train", "eval", "inference"}; |
|
|
|
if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) { |
|
|
|
std::string err_msg = "usage should be train, eval or inference."; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
RETURN_IF_NOT_OK(ValidateStringValue("ManifestNode", usage_, {"train", "eval", "inference"})); |
|
|
|
|
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
@@ -1536,7 +1514,7 @@ Status MnistNode::ValidateParams() { |
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistNode", sampler_)); |
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); |
|
|
|
RETURN_IF_NOT_OK(ValidateStringValue("MnistNode", usage_, {"train", "test", "all"})); |
|
|
|
|
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
@@ -1753,35 +1731,32 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const |
|
|
|
|
|
|
|
Status VOCNode::ValidateParams() { |
|
|
|
Path dir(dataset_dir_); |
|
|
|
if (!dir.IsDirectory()) { |
|
|
|
std::string err_msg = "Invalid dataset path or no dataset path is specified."; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(ValidateDatasetDirParam("VOCNode", dataset_dir_)); |
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(ValidateDatasetSampler("VOCNode", sampler_)); |
|
|
|
|
|
|
|
if (task_ == "Segmentation") { |
|
|
|
if (!class_index_.empty()) { |
|
|
|
std::string err_msg = "class_indexing is invalid in Segmentation task."; |
|
|
|
std::string err_msg = "VOCNode: class_indexing is invalid in Segmentation task."; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
Path imagesets_file = dir / "ImageSets" / "Segmentation" / usage_ + ".txt"; |
|
|
|
if (!imagesets_file.Exists()) { |
|
|
|
std::string err_msg = "Invalid usage: " + usage_ + ", file does not exist"; |
|
|
|
MS_LOG(ERROR) << "Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; |
|
|
|
std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist"; |
|
|
|
MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
} else if (task_ == "Detection") { |
|
|
|
Path imagesets_file = dir / "ImageSets" / "Main" / usage_ + ".txt"; |
|
|
|
if (!imagesets_file.Exists()) { |
|
|
|
std::string err_msg = "Invalid usage: " + usage_ + ", file does not exist"; |
|
|
|
MS_LOG(ERROR) << "Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; |
|
|
|
std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist"; |
|
|
|
MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
} else { |
|
|
|
std::string err_msg = "Invalid task: " + task_; |
|
|
|
std::string err_msg = "VOCNode: Invalid task: " + task_; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
@@ -1859,15 +1834,17 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() { |
|
|
|
|
|
|
|
// Validate the parameters held by BatchNode.
// Reports a syntax error (via RETURN_STATUS_SYNTAX_ERROR, after logging the same
// message) when batch_size_ is not positive or when cols_to_map_ is non-empty;
// otherwise returns Status::OK().
Status BatchNode::ValidateParams() {
  if (batch_size_ <= 0) {
    std::string err_msg = "BatchNode: batch_size should be positive integer, but got: " + std::to_string(batch_size_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // Per the error text, column mapping is not implemented on the C++ side, so
  // any non-empty cols_to_map_ is rejected up front.
  if (!cols_to_map_.empty()) {
    std::string err_msg = "BatchNode: cols_to_map functionality is not implemented in C++; this should be left empty.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}
|
|
|
|
|
|
|
@@ -1906,28 +1883,29 @@ std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() { |
|
|
|
|
|
|
|
Status BucketBatchByLengthNode::ValidateParams() { |
|
|
|
if (element_length_function_ == nullptr && column_names_.size() != 1) { |
|
|
|
std::string err_msg = |
|
|
|
"BucketBatchByLength: element_length_function not specified, but not one column name: " + column_names_.size(); |
|
|
|
std::string err_msg = "BucketBatchByLengthNode: element_length_function not specified, but not one column name: " + |
|
|
|
column_names_.size(); |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
// Check bucket_boundaries: must be positive and strictly increasing |
|
|
|
if (bucket_boundaries_.empty()) { |
|
|
|
std::string err_msg = "BucketBatchByLength: bucket_boundaries cannot be empty."; |
|
|
|
std::string err_msg = "BucketBatchByLengthNode: bucket_boundaries cannot be empty."; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
for (int i = 0; i < bucket_boundaries_.size(); i++) { |
|
|
|
if (bucket_boundaries_[i] <= 0) { |
|
|
|
std::string err_msg = "BucketBatchByLength: Invalid non-positive bucket_boundaries, index: "; |
|
|
|
std::string err_msg = "BucketBatchByLengthNode: Invalid non-positive bucket_boundaries, index: "; |
|
|
|
MS_LOG(ERROR) |
|
|
|
<< "BucketBatchByLength: bucket_boundaries must only contain positive numbers. However, the element at index: " |
|
|
|
<< i << " was: " << bucket_boundaries_[i]; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
if (i > 0 && bucket_boundaries_[i - 1] >= bucket_boundaries_[i]) { |
|
|
|
std::string err_msg = "BucketBatchByLength: Invalid bucket_boundaries not be strictly increasing."; |
|
|
|
std::string err_msg = "BucketBatchByLengthNode: Invalid bucket_boundaries not be strictly increasing."; |
|
|
|
MS_LOG(ERROR) |
|
|
|
<< "BucketBatchByLength: bucket_boundaries must be strictly increasing. However, the elements at index: " |
|
|
|
<< i - 1 << " and " << i << " were: " << bucket_boundaries_[i - 1] << " and " << bucket_boundaries_[i] |
|
|
|
@@ -1938,20 +1916,24 @@ Status BucketBatchByLengthNode::ValidateParams() { |
|
|
|
|
|
|
|
// Check bucket_batch_sizes: must be positive |
|
|
|
if (bucket_batch_sizes_.empty()) { |
|
|
|
std::string err_msg = "BucketBatchByLength: bucket_batch_sizes must be non-empty"; |
|
|
|
std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must be non-empty"; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
if (bucket_batch_sizes_.size() != bucket_boundaries_.size() + 1) { |
|
|
|
std::string err_msg = "BucketBatchByLength: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1"; |
|
|
|
std::string err_msg = |
|
|
|
"BucketBatchByLengthNode: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1"; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
if (std::any_of(bucket_batch_sizes_.begin(), bucket_batch_sizes_.end(), [](int i) { return i <= 0; })) { |
|
|
|
std::string err_msg = "BucketBatchByLength: bucket_batch_sizes must only contain positive numbers."; |
|
|
|
std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must only contain positive numbers."; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1981,26 +1963,26 @@ std::vector<std::shared_ptr<DatasetOp>> BuildVocabNode::Build() { |
|
|
|
|
|
|
|
Status BuildVocabNode::ValidateParams() { |
|
|
|
if (vocab_ == nullptr) { |
|
|
|
std::string err_msg = "BuildVocab: vocab is null."; |
|
|
|
std::string err_msg = "BuildVocabNode: vocab is null."; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
if (top_k_ <= 0) { |
|
|
|
std::string err_msg = "BuildVocab: top_k should be positive, but got: " + top_k_; |
|
|
|
std::string err_msg = "BuildVocabNode: top_k should be positive, but got: " + std::to_string(top_k_); |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
if (freq_range_.first < 0 || freq_range_.second > kDeMaxFreq || freq_range_.first > freq_range_.second) { |
|
|
|
std::string err_msg = "BuildVocab: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)"; |
|
|
|
MS_LOG(ERROR) << "BuildVocab: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), " |
|
|
|
std::string err_msg = "BuildVocabNode: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)"; |
|
|
|
MS_LOG(ERROR) << "BuildVocabNode: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), " |
|
|
|
<< "but got [" << freq_range_.first << ", " << freq_range_.second << "]"; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
if (!columns_.empty()) { |
|
|
|
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocab", "columns", columns_)); |
|
|
|
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocabNode", "columns", columns_)); |
|
|
|
} |
|
|
|
|
|
|
|
return Status::OK(); |
|
|
|
@@ -2014,15 +1996,17 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : |
|
|
|
|
|
|
|
// Validate the parameters held by ConcatNode.
// Reports a syntax error (via RETURN_STATUS_SYNTAX_ERROR, after logging the same
// message) when datasets_ is empty or contains a null entry; otherwise returns
// Status::OK().
Status ConcatNode::ValidateParams() {
  if (datasets_.empty()) {
    std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // Qualified std::find: do not rely on argument-dependent lookup to resolve it.
  if (std::find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
    std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}
|
|
|
|
|
|
|
@@ -2070,7 +2054,7 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() { |
|
|
|
|
|
|
|
Status MapNode::ValidateParams() { |
|
|
|
if (operations_.empty()) { |
|
|
|
std::string err_msg = "Map: No operation is specified."; |
|
|
|
std::string err_msg = "MapNode: No operation is specified."; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
@@ -2158,8 +2142,8 @@ std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() { |
|
|
|
|
|
|
|
Status RepeatNode::ValidateParams() { |
|
|
|
if (repeat_count_ <= 0 && repeat_count_ != -1) { |
|
|
|
std::string err_msg = |
|
|
|
"Repeat: repeat_count should be either -1 or positive integer, repeat_count_: " + std::to_string(repeat_count_); |
|
|
|
std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " + |
|
|
|
std::to_string(repeat_count_); |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
@@ -2211,10 +2195,11 @@ std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() { |
|
|
|
// Function to validate the parameters for SkipNode |
|
|
|
Status SkipNode::ValidateParams() { |
|
|
|
if (skip_count_ <= -1) { |
|
|
|
std::string err_msg = "Skip: skip_count should not be negative, skip_count: " + std::to_string(skip_count_); |
|
|
|
std::string err_msg = "SkipNode: skip_count should not be negative, skip_count: " + std::to_string(skip_count_); |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -2236,7 +2221,7 @@ std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() { |
|
|
|
Status TakeNode::ValidateParams() { |
|
|
|
if (take_count_ <= 0 && take_count_ != -1) { |
|
|
|
std::string err_msg = |
|
|
|
"Take: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_); |
|
|
|
"TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_); |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
@@ -2252,15 +2237,17 @@ ZipNode::ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datase |
|
|
|
|
|
|
|
// Validate the parameters held by ZipNode.
// Reports a syntax error (via RETURN_STATUS_SYNTAX_ERROR, after logging the same
// message) when datasets_ is empty or contains a null entry; otherwise returns
// Status::OK().
Status ZipNode::ValidateParams() {
  if (datasets_.empty()) {
    std::string err_msg = "ZipNode: datasets to zip are not specified.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // Qualified std::find: do not rely on argument-dependent lookup to resolve it.
  if (std::find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
    std::string err_msg = "ZipNode: zip datasets should not be null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}
|
|
|
|
|
|
|
|