| @@ -70,6 +70,7 @@ void BatchNode::Print(std::ostream &out) const { | |||||
| } | } | ||||
| Status BatchNode::ValidateParams() { | Status BatchNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (batch_size_ <= 0) { | if (batch_size_ <= 0) { | ||||
| std::string err_msg = "BatchNode: batch_size should be positive integer, but got: " + std::to_string(batch_size_); | std::string err_msg = "BatchNode: batch_size should be positive integer, but got: " + std::to_string(batch_size_); | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| @@ -94,6 +94,7 @@ Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *n | |||||
| } | } | ||||
| Status BucketBatchByLengthNode::ValidateParams() { | Status BucketBatchByLengthNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (element_length_function_ == nullptr && column_names_.size() != 1) { | if (element_length_function_ == nullptr && column_names_.size() != 1) { | ||||
| std::string err_msg = | std::string err_msg = | ||||
| "BucketBatchByLengthNode: when element_length_function is not specified, size of column_name must be 1 but is: " + | "BucketBatchByLengthNode: when element_length_function is not specified, size of column_name must be 1 but is: " + | ||||
| @@ -63,6 +63,7 @@ Status BuildSentenceVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *no | |||||
| } | } | ||||
| Status BuildSentenceVocabNode::ValidateParams() { | Status BuildSentenceVocabNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (vocab_ == nullptr) { | if (vocab_ == nullptr) { | ||||
| std::string err_msg = "BuildSentenceVocabNode: vocab is null."; | std::string err_msg = "BuildSentenceVocabNode: vocab is null."; | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| @@ -59,6 +59,7 @@ Status BuildVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) | |||||
| } | } | ||||
| Status BuildVocabNode::ValidateParams() { | Status BuildVocabNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (vocab_ == nullptr) { | if (vocab_ == nullptr) { | ||||
| std::string err_msg = "BuildVocabNode: vocab is null."; | std::string err_msg = "BuildVocabNode: vocab is null."; | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| @@ -49,6 +49,7 @@ std::shared_ptr<DatasetNode> ConcatNode::Copy() { | |||||
| void ConcatNode::Print(std::ostream &out) const { out << Name(); } | void ConcatNode::Print(std::ostream &out) const { out << Name(); } | ||||
| Status ConcatNode::ValidateParams() { | Status ConcatNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (children_.size() < 2) { | if (children_.size() < 2) { | ||||
| std::string err_msg = "ConcatNode: concatenated datasets are not specified."; | std::string err_msg = "ConcatNode: concatenated datasets are not specified."; | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <limits> | |||||
| #include <memory> | #include <memory> | ||||
| #include <set> | #include <set> | ||||
| @@ -220,7 +221,7 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) { | |||||
| return shared_from_this(); | return shared_from_this(); | ||||
| } | } | ||||
| DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) { | |||||
| DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}), dataset_size_(-1) { | |||||
| // Fetch some default value from config manager | // Fetch some default value from config manager | ||||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | ||||
| num_workers_ = cfg->num_parallel_workers(); | num_workers_ = cfg->num_parallel_workers(); | ||||
| @@ -418,6 +419,13 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz | |||||
| RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); | RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); | ||||
| } | } | ||||
| } | } | ||||
| Status DatasetNode::ValidateParams() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||||
| num_workers_ > 0 && num_workers_ < std::numeric_limits<uint16_t>::max(), | |||||
| Name() + "'s num_workers=" + std::to_string(num_workers_) + ", this value is less than 1 or too large."); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status MappableSourceNode::Accept(IRNodePass *p, bool *modified) { | Status MappableSourceNode::Accept(IRNodePass *p, bool *modified) { | ||||
| return p->Visit(shared_from_base<MappableSourceNode>(), modified); | return p->Visit(shared_from_base<MappableSourceNode>(), modified); | ||||
| @@ -150,9 +150,9 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||||
| /// \return Status Status::OK() if build successfully | /// \return Status Status::OK() if build successfully | ||||
| virtual Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) = 0; | virtual Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) = 0; | ||||
| /// \brief Pure virtual function for derived class to implement parameters validation | |||||
| /// \brief base virtual function for derived class to implement parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| virtual Status ValidateParams() = 0; | |||||
| virtual Status ValidateParams(); | |||||
| /// \brief Pure virtual function for derived class to get the shard id of specific node | /// \brief Pure virtual function for derived class to get the shard id of specific node | ||||
| /// \return Status Status::OK() if get shard id successfully | /// \return Status Status::OK() if get shard id successfully | ||||
| @@ -262,7 +262,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||||
| std::vector<std::shared_ptr<DatasetNode>> children_; | std::vector<std::shared_ptr<DatasetNode>> children_; | ||||
| DatasetNode *parent_; // used to record the only one parent of an IR node after parsing phase | DatasetNode *parent_; // used to record the only one parent of an IR node after parsing phase | ||||
| std::shared_ptr<DatasetCache> cache_; | std::shared_ptr<DatasetCache> cache_; | ||||
| int64_t dataset_size_ = -1; | |||||
| int64_t dataset_size_; | |||||
| int32_t num_workers_; | int32_t num_workers_; | ||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| int32_t connector_que_size_; | int32_t connector_que_size_; | ||||
| @@ -48,6 +48,7 @@ Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||||
| // Function to validate the parameters for EpochCtrlNode | // Function to validate the parameters for EpochCtrlNode | ||||
| Status EpochCtrlNode::ValidateParams() { | Status EpochCtrlNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (num_epochs_ <= 0 && num_epochs_ != -1) { | if (num_epochs_ <= 0 && num_epochs_ != -1) { | ||||
| std::string err_msg = | std::string err_msg = | ||||
| "EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_); | "EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_); | ||||
| @@ -49,6 +49,7 @@ Status FilterNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||||
| } | } | ||||
| Status FilterNode::ValidateParams() { | Status FilterNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (predicate_ == nullptr) { | if (predicate_ == nullptr) { | ||||
| std::string err_msg = "FilterNode: predicate is not specified."; | std::string err_msg = "FilterNode: predicate is not specified."; | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| @@ -82,6 +82,7 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||||
| } | } | ||||
| Status MapNode::ValidateParams() { | Status MapNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (operations_.empty()) { | if (operations_.empty()) { | ||||
| std::string err_msg = "MapNode: No operation is specified."; | std::string err_msg = "MapNode: No operation is specified."; | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| @@ -40,6 +40,7 @@ std::shared_ptr<DatasetNode> ProjectNode::Copy() { | |||||
| void ProjectNode::Print(std::ostream &out) const { out << Name() + "(column: " + PrintColumns(columns_) + ")"; } | void ProjectNode::Print(std::ostream &out) const { out << Name() + "(column: " + PrintColumns(columns_) + ")"; } | ||||
| Status ProjectNode::ValidateParams() { | Status ProjectNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (columns_.empty()) { | if (columns_.empty()) { | ||||
| std::string err_msg = "ProjectNode: No columns are specified."; | std::string err_msg = "ProjectNode: No columns are specified."; | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| @@ -43,6 +43,7 @@ void RenameNode::Print(std::ostream &out) const { | |||||
| } | } | ||||
| Status RenameNode::ValidateParams() { | Status RenameNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (input_columns_.size() != output_columns_.size()) { | if (input_columns_.size() != output_columns_.size()) { | ||||
| std::string err_msg = "RenameNode: input and output columns must be the same size"; | std::string err_msg = "RenameNode: input and output columns must be the same size"; | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| @@ -43,6 +43,7 @@ Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||||
| } | } | ||||
| Status RepeatNode::ValidateParams() { | Status RepeatNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (repeat_count_ <= 0 && repeat_count_ != -1) { | if (repeat_count_ <= 0 && repeat_count_ != -1) { | ||||
| std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " + | std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " + | ||||
| std::to_string(repeat_count_); | std::to_string(repeat_count_); | ||||
| @@ -49,6 +49,7 @@ Status RootNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||||
| // Function to validate the parameters for RootNode | // Function to validate the parameters for RootNode | ||||
| Status RootNode::ValidateParams() { | Status RootNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (num_epochs_ <= 0 && num_epochs_ != -1) { | if (num_epochs_ <= 0 && num_epochs_ != -1) { | ||||
| std::string err_msg = | std::string err_msg = | ||||
| "RootNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_); | "RootNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_); | ||||
| @@ -51,6 +51,7 @@ Status ShuffleNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||||
| // Function to validate the parameters for ShuffleNode | // Function to validate the parameters for ShuffleNode | ||||
| Status ShuffleNode::ValidateParams() { | Status ShuffleNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (shuffle_size_ <= 1) { | if (shuffle_size_ <= 1) { | ||||
| std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_); | std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_); | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| @@ -45,6 +45,7 @@ Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||||
| // Function to validate the parameters for SkipNode | // Function to validate the parameters for SkipNode | ||||
| Status SkipNode::ValidateParams() { | Status SkipNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (skip_count_ <= -1) { | if (skip_count_ <= -1) { | ||||
| std::string err_msg = "SkipNode: skip_count should not be negative, skip_count: " + std::to_string(skip_count_); | std::string err_msg = "SkipNode: skip_count should not be negative, skip_count: " + std::to_string(skip_count_); | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| @@ -50,6 +50,7 @@ void AlbumNode::Print(std::ostream &out) const { | |||||
| } | } | ||||
| Status AlbumNode::ValidateParams() { | Status AlbumNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_)); | ||||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AlbumNode", {schema_path_})); | RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AlbumNode", {schema_path_})); | ||||
| @@ -50,6 +50,7 @@ void CelebANode::Print(std::ostream &out) const { | |||||
| } | } | ||||
| Status CelebANode::ValidateParams() { | Status CelebANode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_)); | ||||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebANode", sampler_)); | RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebANode", sampler_)); | ||||
| @@ -43,6 +43,7 @@ void Cifar100Node::Print(std::ostream &out) const { | |||||
| } | } | ||||
| Status Cifar100Node::ValidateParams() { | Status Cifar100Node::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_)); | ||||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Node", sampler_)); | RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Node", sampler_)); | ||||
| @@ -43,6 +43,7 @@ void Cifar10Node::Print(std::ostream &out) const { | |||||
| } | } | ||||
| Status Cifar10Node::ValidateParams() { | Status Cifar10Node::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_)); | ||||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Node", sampler_)); | RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Node", sampler_)); | ||||
| @@ -53,6 +53,7 @@ void CLUENode::Print(std::ostream &out) const { | |||||
| } | } | ||||
| Status CLUENode::ValidateParams() { | Status CLUENode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_)); | RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_)); | ||||
| RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", task_, {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"})); | RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", task_, {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"})); | ||||
| @@ -46,6 +46,7 @@ std::shared_ptr<DatasetNode> CocoNode::Copy() { | |||||
| void CocoNode::Print(std::ostream &out) const { out << Name(); } | void CocoNode::Print(std::ostream &out) const { out << Name(); } | ||||
| Status CocoNode::ValidateParams() { | Status CocoNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_)); | ||||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("CocoNode", sampler_)); | RETURN_IF_NOT_OK(ValidateDatasetSampler("CocoNode", sampler_)); | ||||
| @@ -61,6 +61,7 @@ void CSVNode::Print(std::ostream &out) const { | |||||
| } | } | ||||
| Status CSVNode::ValidateParams() { | Status CSVNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_)); | RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_)); | ||||
| if (field_delim_ == '"' || field_delim_ == '\r' || field_delim_ == '\n') { | if (field_delim_ == '"' || field_delim_ == '\r' || field_delim_ == '\n') { | ||||
| @@ -83,7 +83,10 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||||
| } | } | ||||
| // no validation is needed for generator op. | // no validation is needed for generator op. | ||||
| Status GeneratorNode::ValidateParams() { return Status::OK(); } | |||||
| Status GeneratorNode::ValidateParams() { | |||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GeneratorNode::GetShardId(int32_t *shard_id) { | Status GeneratorNode::GetShardId(int32_t *shard_id) { | ||||
| RETURN_UNEXPECTED_IF_NULL(shard_id); | RETURN_UNEXPECTED_IF_NULL(shard_id); | ||||
| @@ -53,6 +53,7 @@ void ImageFolderNode::Print(std::ostream &out) const { | |||||
| } | } | ||||
| Status ImageFolderNode::ValidateParams() { | Status ImageFolderNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_)); | ||||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("ImageFolderNode", sampler_)); | RETURN_IF_NOT_OK(ValidateDatasetSampler("ImageFolderNode", sampler_)); | ||||
| @@ -57,6 +57,7 @@ void ManifestNode::Print(std::ostream &out) const { | |||||
| } | } | ||||
| Status ManifestNode::ValidateParams() { | Status ManifestNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'}; | std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'}; | ||||
| for (char c : dataset_file_) { | for (char c : dataset_file_) { | ||||
| auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c); | auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c); | ||||
| @@ -67,6 +67,7 @@ std::shared_ptr<DatasetNode> MindDataNode::Copy() { | |||||
| void MindDataNode::Print(std::ostream &out) const { out << Name() + "(file:" + dataset_file_ + ",...)"; } | void MindDataNode::Print(std::ostream &out) const { out << Name() + "(file:" + dataset_file_ + ",...)"; } | ||||
| Status MindDataNode::ValidateParams() { | Status MindDataNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (!search_for_pattern_ && dataset_files_.size() > 4096) { | if (!search_for_pattern_ && dataset_files_.size() > 4096) { | ||||
| std::string err_msg = | std::string err_msg = | ||||
| "MindDataNode: length of dataset_file must be less than or equal to 4096, dataset_file length: " + | "MindDataNode: length of dataset_file must be less than or equal to 4096, dataset_file length: " + | ||||
| @@ -40,6 +40,7 @@ std::shared_ptr<DatasetNode> MnistNode::Copy() { | |||||
| void MnistNode::Print(std::ostream &out) const { out << Name(); } | void MnistNode::Print(std::ostream &out) const { out << Name(); } | ||||
| Status MnistNode::ValidateParams() { | Status MnistNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_)); | ||||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistNode", sampler_)); | RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistNode", sampler_)); | ||||
| @@ -41,6 +41,7 @@ void RandomNode::Print(std::ostream &out) const { out << Name() + "(num_row:" + | |||||
| // ValidateParams for RandomNode | // ValidateParams for RandomNode | ||||
| Status RandomNode::ValidateParams() { | Status RandomNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (total_rows_ < 0) { | if (total_rows_ < 0) { | ||||
| std::string err_msg = | std::string err_msg = | ||||
| "RandomNode: total_rows must be greater than or equal 0, now get " + std::to_string(total_rows_); | "RandomNode: total_rows must be greater than or equal 0, now get " + std::to_string(total_rows_); | ||||
| @@ -55,6 +55,7 @@ void TextFileNode::Print(std::ostream &out) const { | |||||
| } | } | ||||
| Status TextFileNode::ValidateParams() { | Status TextFileNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_)); | RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_)); | ||||
| if (num_samples_ < 0) { | if (num_samples_ < 0) { | ||||
| @@ -49,6 +49,7 @@ void TFRecordNode::Print(std::ostream &out) const { | |||||
| // Validator for TFRecordNode | // Validator for TFRecordNode | ||||
| Status TFRecordNode::ValidateParams() { | Status TFRecordNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (dataset_files_.empty()) { | if (dataset_files_.empty()) { | ||||
| std::string err_msg = "TFRecordNode: dataset_files is not specified."; | std::string err_msg = "TFRecordNode: dataset_files is not specified."; | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| @@ -49,6 +49,7 @@ std::shared_ptr<DatasetNode> VOCNode::Copy() { | |||||
| void VOCNode::Print(std::ostream &out) const { out << Name(); } | void VOCNode::Print(std::ostream &out) const { out << Name(); } | ||||
| Status VOCNode::ValidateParams() { | Status VOCNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| Path dir(dataset_dir_); | Path dir(dataset_dir_); | ||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("VOCNode", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("VOCNode", dataset_dir_)); | ||||
| @@ -52,7 +52,10 @@ Status SyncWaitNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||||
| } | } | ||||
| // Function to validate the parameters for SyncWaitNode | // Function to validate the parameters for SyncWaitNode | ||||
| Status SyncWaitNode::ValidateParams() { return Status::OK(); } | |||||
| Status SyncWaitNode::ValidateParams() { | |||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -46,6 +46,7 @@ Status TakeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||||
| // Function to validate the parameters for TakeNode | // Function to validate the parameters for TakeNode | ||||
| Status TakeNode::ValidateParams() { | Status TakeNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (take_count_ <= 0 && take_count_ != -1) { | if (take_count_ <= 0 && take_count_ != -1) { | ||||
| std::string err_msg = | std::string err_msg = | ||||
| "TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_); | "TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_); | ||||
| @@ -57,6 +57,7 @@ void TransferNode::Print(std::ostream &out) const { | |||||
| // Validator for TransferNode | // Validator for TransferNode | ||||
| Status TransferNode::ValidateParams() { | Status TransferNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (total_batch_ < 0) { | if (total_batch_ < 0) { | ||||
| std::string err_msg = "TransferNode: Total batches should be >= 0, value given: "; | std::string err_msg = "TransferNode: Total batches should be >= 0, value given: "; | ||||
| MS_LOG(ERROR) << err_msg << total_batch_; | MS_LOG(ERROR) << err_msg << total_batch_; | ||||
| @@ -41,6 +41,7 @@ std::shared_ptr<DatasetNode> ZipNode::Copy() { | |||||
| void ZipNode::Print(std::ostream &out) const { out << Name(); } | void ZipNode::Print(std::ostream &out) const { out << Name(); } | ||||
| Status ZipNode::ValidateParams() { | Status ZipNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||||
| if (children_.size() < 2) { | if (children_.size() < 2) { | ||||
| std::string err_msg = "ZipNode: input datasets are not specified."; | std::string err_msg = "ZipNode: input datasets are not specified."; | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| @@ -1833,3 +1833,20 @@ TEST_F(MindDataTestPipeline, TestZipSuccess2) { | |||||
| // Manually terminate the pipeline | // Manually terminate the pipeline | ||||
| iter->Stop(); | iter->Stop(); | ||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestNumWorkersValidate) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNumWorkersValidate."; | |||||
| // Create an ImageFolder Dataset | |||||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path); | |||||
| // ds needs to be non nullptr otherwise, the subsequent logic will core dump | |||||
| ASSERT_NE(ds, nullptr); | |||||
| // test if set num_workers=-1 | |||||
| EXPECT_EQ(ds->SetNumWorkers(-1)->CreateIterator(), nullptr); | |||||
| // test if set num_workers can be very large | |||||
| EXPECT_EQ(ds->SetNumWorkers(INT32_MAX)->CreateIterator(), nullptr); | |||||
| } | |||||