diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc index 40db6edb9f..dfe2190125 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc @@ -70,6 +70,7 @@ void BatchNode::Print(std::ostream &out) const { } Status BatchNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (batch_size_ <= 0) { std::string err_msg = "BatchNode: batch_size should be positive integer, but got: " + std::to_string(batch_size_); MS_LOG(ERROR) << err_msg; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc index 4bbd4316a7..c1cea3a8c1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc @@ -94,6 +94,7 @@ Status BucketBatchByLengthNode::Build(std::vector> *n } Status BucketBatchByLengthNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (element_length_function_ == nullptr && column_names_.size() != 1) { std::string err_msg = "BucketBatchByLengthNode: when element_length_function is not specified, size of column_name must be 1 but is: " + diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc index b012e72223..4bfcdc5576 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc @@ -63,6 +63,7 @@ Status BuildSentenceVocabNode::Build(std::vector> *no } Status BuildSentenceVocabNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (vocab_ == nullptr) { std::string err_msg = "BuildSentenceVocabNode: vocab is null."; MS_LOG(ERROR) << err_msg; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc index 714f967d75..a4041ac7e3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc @@ -59,6 +59,7 @@ Status BuildVocabNode::Build(std::vector> *node_ops) } Status BuildVocabNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (vocab_ == nullptr) { std::string err_msg = "BuildVocabNode: vocab is null."; MS_LOG(ERROR) << err_msg; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc index 8ff227fa5d..bf72287682 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc @@ -49,6 +49,7 @@ std::shared_ptr ConcatNode::Copy() { void ConcatNode::Print(std::ostream &out) const { out << Name(); } Status ConcatNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (children_.size() < 2) { std::string err_msg = "ConcatNode: concatenated datasets are not specified."; MS_LOG(ERROR) << err_msg; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc index 242a7fcf8f..dbb8989619 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc @@ -17,6 +17,7 @@ #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" #include +#include #include #include @@ -220,7 +221,7 @@ std::shared_ptr DatasetNode::SetNumWorkers(int32_t num_workers) { return shared_from_this(); } -DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) { +DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}), dataset_size_(-1) { // Fetch some default value from config manager std::shared_ptr cfg = GlobalContext::config_manager(); num_workers_ = cfg->num_parallel_workers(); @@ -418,6 +419,13 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr &siz RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); } } +Status DatasetNode::ValidateParams() { + CHECK_FAIL_RETURN_UNEXPECTED( + num_workers_ > 0 && num_workers_ < std::numeric_limits::max(), + Name() + "'s num_workers=" + std::to_string(num_workers_) + ", this value is less than 1 or too large."); + + return Status::OK(); +} Status MappableSourceNode::Accept(IRNodePass *p, bool *modified) { return p->Visit(shared_from_base(), modified); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index 5b5e78a311..a4c3459d80 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -150,9 +150,9 @@ class DatasetNode : public std::enable_shared_from_this { /// \return Status Status::OK() if build successfully virtual Status Build(std::vector> *node_ops) = 0; - /// \brief Pure virtual function for derived class to implement parameters validation + /// \brief base virtual function for derived class to implement parameters validation /// \return Status Status::OK() if all the parameters are valid - virtual Status ValidateParams() = 0; + virtual Status ValidateParams(); /// \brief Pure virtual function for derived class to get the shard id of specific node /// \return Status Status::OK() if get shard id successfully @@ -262,7 +262,7 @@ class DatasetNode : public std::enable_shared_from_this { std::vector> children_; DatasetNode *parent_; // used to record the only one parent of an IR node after parsing phase std::shared_ptr cache_; - int64_t dataset_size_ = -1; + int64_t dataset_size_; int32_t num_workers_; int32_t rows_per_buffer_; int32_t connector_que_size_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc index 2e3f473cea..8bea432e5b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc @@ -48,6 +48,7 @@ Status EpochCtrlNode::Build(std::vector> *node_ops) { // Function to validate the parameters for EpochCtrlNode Status EpochCtrlNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (num_epochs_ <= 0 && num_epochs_ != -1) { std::string err_msg = "EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc index 449371798a..2da7e15b39 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc @@ -49,6 +49,7 @@ Status FilterNode::Build(std::vector> *node_ops) { } Status FilterNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (predicate_ == nullptr) { std::string err_msg = "FilterNode: predicate is not specified."; MS_LOG(ERROR) << err_msg; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc index 71532b4590..e010e8b498 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc @@ -82,6 +82,7 @@ Status MapNode::Build(std::vector> *node_ops) { } Status MapNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (operations_.empty()) { std::string err_msg = "MapNode: No operation is specified."; MS_LOG(ERROR) << err_msg; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc index c9fcc3cefe..66b8acc1dd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc @@ -40,6 +40,7 @@ std::shared_ptr ProjectNode::Copy() { void ProjectNode::Print(std::ostream &out) const { out << Name() + "(column: " + PrintColumns(columns_) + ")"; } Status ProjectNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (columns_.empty()) { std::string err_msg = "ProjectNode: No columns are specified."; MS_LOG(ERROR) << err_msg; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc index 3f162d76e8..31cc7ca1e4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc @@ -43,6 +43,7 @@ void RenameNode::Print(std::ostream &out) const { } Status RenameNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (input_columns_.size() != output_columns_.size()) { std::string err_msg = "RenameNode: input and output columns must be the same size"; MS_LOG(ERROR) << err_msg; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc index 8da4ee6392..adc0f05356 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc @@ -43,6 +43,7 @@ Status RepeatNode::Build(std::vector> *node_ops) { } Status RepeatNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (repeat_count_ <= 0 && repeat_count_ != -1) { std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " + std::to_string(repeat_count_); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc index b0375e2791..31123b709e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc @@ -49,6 +49,7 @@ Status RootNode::Build(std::vector> *node_ops) { // Function to validate the parameters for RootNode Status RootNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (num_epochs_ <= 0 && num_epochs_ != -1) { std::string err_msg = "RootNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc index 6930d76eb7..b5bdc203ec 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc @@ -51,6 +51,7 @@ Status ShuffleNode::Build(std::vector> *node_ops) { // Function to validate the parameters for ShuffleNode Status ShuffleNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (shuffle_size_ <= 1) { std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_); MS_LOG(ERROR) << err_msg; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc index 32100c3ca1..b2db03902c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc @@ -45,6 +45,7 @@ Status SkipNode::Build(std::vector> *node_ops) { // Function to validate the parameters for SkipNode Status SkipNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (skip_count_ <= -1) { std::string err_msg = "SkipNode: skip_count should not be negative, skip_count: " + std::to_string(skip_count_); MS_LOG(ERROR) << err_msg; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc index be8ee97286..924c351c10 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc @@ -50,6 +50,7 @@ void AlbumNode::Print(std::ostream &out) const { } Status AlbumNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AlbumNode", {schema_path_})); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc index 8a099de335..cd5638c163 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc @@ -50,6 +50,7 @@ void CelebANode::Print(std::ostream &out) const { } Status CelebANode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebANode", sampler_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc index c901e3a832..c4e1643c49 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc @@ -43,6 +43,7 @@ void Cifar100Node::Print(std::ostream &out) const { } Status Cifar100Node::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Node", sampler_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc index 1e4f45dd42..979893bb90 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc @@ -43,6 +43,7 @@ void Cifar10Node::Print(std::ostream &out) const { } Status Cifar10Node::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Node", sampler_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc index 6f956fe960..570962d5b8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc @@ -53,6 +53,7 @@ void CLUENode::Print(std::ostream &out) const { } Status CLUENode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_)); RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", task_, {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"})); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc index 49b9c6fea7..5d9953d8b7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc @@ -46,6 +46,7 @@ std::shared_ptr CocoNode::Copy() { void CocoNode::Print(std::ostream &out) const { out << Name(); } Status CocoNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetSampler("CocoNode", sampler_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc index 74ce569fba..b8dc3243e9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc @@ -61,6 +61,7 @@ void CSVNode::Print(std::ostream &out) const { } Status CSVNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_)); if (field_delim_ == '"' || field_delim_ == '\r' || field_delim_ == '\n') { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc index 9c8a7ad86a..5c7bed1109 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc @@ -83,7 +83,10 @@ Status GeneratorNode::Build(std::vector> *node_ops) { } // no validation is needed for generator op. -Status GeneratorNode::ValidateParams() { return Status::OK(); } +Status GeneratorNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + return Status::OK(); +} Status GeneratorNode::GetShardId(int32_t *shard_id) { RETURN_UNEXPECTED_IF_NULL(shard_id); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc index fa6e8287d3..0448672864 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc @@ -53,6 +53,7 @@ void ImageFolderNode::Print(std::ostream &out) const { } Status ImageFolderNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetSampler("ImageFolderNode", sampler_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc index c63052a71f..7b033efbf5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc @@ -57,6 +57,7 @@ void ManifestNode::Print(std::ostream &out) const { } Status ManifestNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); std::vector forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'}; for (char c : dataset_file_) { auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc index 96ffac6389..d7baf79dd3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc @@ -67,6 +67,7 @@ std::shared_ptr MindDataNode::Copy() { void MindDataNode::Print(std::ostream &out) const { out << Name() + "(file:" + dataset_file_ + ",...)"; } Status MindDataNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (!search_for_pattern_ && dataset_files_.size() > 4096) { std::string err_msg = "MindDataNode: length of dataset_file must be less than or equal to 4096, dataset_file length: " + diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc index 7766561df2..1b7471a53f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc @@ -40,6 +40,7 @@ std::shared_ptr MnistNode::Copy() { void MnistNode::Print(std::ostream &out) const { out << Name(); } Status MnistNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistNode", sampler_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc index 70910746dc..27c3743d2e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc @@ -41,6 +41,7 @@ void RandomNode::Print(std::ostream &out) const { out << Name() + "(num_row:" + // ValidateParams for RandomNode Status RandomNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (total_rows_ < 0) { std::string err_msg = "RandomNode: total_rows must be greater than or equal 0, now get " + std::to_string(total_rows_); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc index 5c856d89e7..9e168b94b0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc @@ -55,6 +55,7 @@ void TextFileNode::Print(std::ostream &out) const { } Status TextFileNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_)); if (num_samples_ < 0) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc index 66558bb285..a03aa77e1e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc @@ -49,6 +49,7 @@ void TFRecordNode::Print(std::ostream &out) const { // Validator for TFRecordNode Status TFRecordNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (dataset_files_.empty()) { std::string err_msg = "TFRecordNode: dataset_files is not specified."; MS_LOG(ERROR) << err_msg; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc index 7a8fd8f2bf..2a66e29bea 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc @@ -49,6 +49,7 @@ std::shared_ptr VOCNode::Copy() { void VOCNode::Print(std::ostream &out) const { out << Name(); } Status VOCNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); Path dir(dataset_dir_); RETURN_IF_NOT_OK(ValidateDatasetDirParam("VOCNode", dataset_dir_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc index 5886da228a..7013acc7a8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc @@ -52,7 +52,10 @@ Status SyncWaitNode::Build(std::vector> *node_ops) { } // Function to validate the parameters for SyncWaitNode -Status SyncWaitNode::ValidateParams() { return Status::OK(); } +Status SyncWaitNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc index 5d6e12bd18..2b15ede536 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc @@ -46,6 +46,7 @@ Status TakeNode::Build(std::vector> *node_ops) { // Function to validate the parameters for TakeNode Status TakeNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (take_count_ <= 0 && take_count_ != -1) { std::string err_msg = "TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc index 1003f7de13..f41180a8d7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc @@ -57,6 +57,7 @@ void TransferNode::Print(std::ostream &out) const { // Validator for TransferNode Status TransferNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (total_batch_ < 0) { std::string err_msg = "TransferNode: Total batches should be >= 0, value given: "; MS_LOG(ERROR) << err_msg << total_batch_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc index 03f89a4b44..2c525ca179 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc @@ -41,6 +41,7 @@ std::shared_ptr ZipNode::Copy() { void ZipNode::Print(std::ostream &out) const { out << Name(); } Status ZipNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); if (children_.size() < 2) { std::string err_msg = "ZipNode: input datasets are not specified."; MS_LOG(ERROR) << err_msg; diff --git a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc index 0006a64444..880f715cd9 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc @@ -1833,3 +1833,20 @@ TEST_F(MindDataTestPipeline, TestZipSuccess2) { // Manually terminate the pipeline iter->Stop(); } + +TEST_F(MindDataTestPipeline, TestNumWorkersValidate) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNumWorkersValidate."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path); + + // ds needs to be non nullptr otherwise, the subsequent logic will core dump + ASSERT_NE(ds, nullptr); + + // test if set num_workers=-1 + EXPECT_EQ(ds->SetNumWorkers(-1)->CreateIterator(), nullptr); + + // test if set num_workers can be very large + EXPECT_EQ(ds->SetNumWorkers(INT32_MAX)->CreateIterator(), nullptr); +} \ No newline at end of file