Browse Source

add validate param to base class DatasetNode

tags/v1.1.0
Zirui Wu 5 years ago
parent
commit
0e2f7a9e9e
37 changed files with 69 additions and 6 deletions
  1. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc
  2. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc
  3. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc
  4. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc
  5. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc
  6. +9
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  7. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h
  8. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc
  9. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc
  10. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc
  11. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc
  12. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc
  13. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc
  14. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc
  15. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc
  16. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc
  17. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc
  18. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc
  19. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc
  20. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc
  21. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc
  22. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc
  23. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc
  24. +4
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc
  25. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc
  26. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc
  27. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc
  28. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc
  29. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc
  30. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc
  31. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc
  32. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc
  33. +4
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc
  34. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc
  35. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc
  36. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc
  37. +17
    -0
      tests/ut/cpp/dataset/c_api_dataset_ops_test.cc

+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc View File

@@ -70,6 +70,7 @@ void BatchNode::Print(std::ostream &out) const {
}

Status BatchNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (batch_size_ <= 0) {
std::string err_msg = "BatchNode: batch_size should be positive integer, but got: " + std::to_string(batch_size_);
MS_LOG(ERROR) << err_msg;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc View File

@@ -94,6 +94,7 @@ Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *n
}

Status BucketBatchByLengthNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (element_length_function_ == nullptr && column_names_.size() != 1) {
std::string err_msg =
"BucketBatchByLengthNode: when element_length_function is not specified, size of column_name must be 1 but is: " +


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc View File

@@ -63,6 +63,7 @@ Status BuildSentenceVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *no
}

Status BuildSentenceVocabNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (vocab_ == nullptr) {
std::string err_msg = "BuildSentenceVocabNode: vocab is null.";
MS_LOG(ERROR) << err_msg;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc View File

@@ -59,6 +59,7 @@ Status BuildVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
}

Status BuildVocabNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (vocab_ == nullptr) {
std::string err_msg = "BuildVocabNode: vocab is null.";
MS_LOG(ERROR) << err_msg;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc View File

@@ -49,6 +49,7 @@ std::shared_ptr<DatasetNode> ConcatNode::Copy() {
void ConcatNode::Print(std::ostream &out) const { out << Name(); }

Status ConcatNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (children_.size() < 2) {
std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
MS_LOG(ERROR) << err_msg;


+ 9
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc View File

@@ -17,6 +17,7 @@
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <set>

@@ -220,7 +221,7 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
return shared_from_this();
}

DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) {
DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}), dataset_size_(-1) {
// Fetch some default value from config manager
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
num_workers_ = cfg->num_parallel_workers();
@@ -418,6 +419,13 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
}
}
Status DatasetNode::ValidateParams() {
CHECK_FAIL_RETURN_UNEXPECTED(
num_workers_ > 0 && num_workers_ < std::numeric_limits<uint16_t>::max(),
Name() + "'s num_workers=" + std::to_string(num_workers_) + ", this value is less than 1 or too large.");

return Status::OK();
}

Status MappableSourceNode::Accept(IRNodePass *p, bool *modified) {
return p->Visit(shared_from_base<MappableSourceNode>(), modified);


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h View File

@@ -150,9 +150,9 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return Status Status::OK() if build successfully
virtual Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) = 0;

/// \brief Pure virtual function for derived class to implement parameters validation
/// \brief base virtual function for derived class to implement parameters validation
/// \return Status Status::OK() if all the parameters are valid
virtual Status ValidateParams() = 0;
virtual Status ValidateParams();

/// \brief Pure virtual function for derived class to get the shard id of specific node
/// \return Status Status::OK() if get shard id successfully
@@ -262,7 +262,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
std::vector<std::shared_ptr<DatasetNode>> children_;
DatasetNode *parent_; // used to record the only one parent of an IR node after parsing phase
std::shared_ptr<DatasetCache> cache_;
int64_t dataset_size_ = -1;
int64_t dataset_size_;
int32_t num_workers_;
int32_t rows_per_buffer_;
int32_t connector_que_size_;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc View File

@@ -48,6 +48,7 @@ Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {

// Function to validate the parameters for EpochCtrlNode
Status EpochCtrlNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (num_epochs_ <= 0 && num_epochs_ != -1) {
std::string err_msg =
"EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_);


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc View File

@@ -49,6 +49,7 @@ Status FilterNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
}

Status FilterNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (predicate_ == nullptr) {
std::string err_msg = "FilterNode: predicate is not specified.";
MS_LOG(ERROR) << err_msg;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc View File

@@ -82,6 +82,7 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
}

Status MapNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (operations_.empty()) {
std::string err_msg = "MapNode: No operation is specified.";
MS_LOG(ERROR) << err_msg;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc View File

@@ -40,6 +40,7 @@ std::shared_ptr<DatasetNode> ProjectNode::Copy() {
void ProjectNode::Print(std::ostream &out) const { out << Name() + "(column: " + PrintColumns(columns_) + ")"; }

Status ProjectNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (columns_.empty()) {
std::string err_msg = "ProjectNode: No columns are specified.";
MS_LOG(ERROR) << err_msg;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc View File

@@ -43,6 +43,7 @@ void RenameNode::Print(std::ostream &out) const {
}

Status RenameNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (input_columns_.size() != output_columns_.size()) {
std::string err_msg = "RenameNode: input and output columns must be the same size";
MS_LOG(ERROR) << err_msg;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc View File

@@ -43,6 +43,7 @@ Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
}

Status RepeatNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (repeat_count_ <= 0 && repeat_count_ != -1) {
std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " +
std::to_string(repeat_count_);


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc View File

@@ -49,6 +49,7 @@ Status RootNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {

// Function to validate the parameters for RootNode
Status RootNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (num_epochs_ <= 0 && num_epochs_ != -1) {
std::string err_msg =
"RootNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_);


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc View File

@@ -51,6 +51,7 @@ Status ShuffleNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {

// Function to validate the parameters for ShuffleNode
Status ShuffleNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (shuffle_size_ <= 1) {
std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_);
MS_LOG(ERROR) << err_msg;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc View File

@@ -45,6 +45,7 @@ Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {

// Function to validate the parameters for SkipNode
Status SkipNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (skip_count_ <= -1) {
std::string err_msg = "SkipNode: skip_count should not be negative, skip_count: " + std::to_string(skip_count_);
MS_LOG(ERROR) << err_msg;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc View File

@@ -50,6 +50,7 @@ void AlbumNode::Print(std::ostream &out) const {
}

Status AlbumNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_));

RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AlbumNode", {schema_path_}));


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc View File

@@ -50,6 +50,7 @@ void CelebANode::Print(std::ostream &out) const {
}

Status CelebANode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_));

RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebANode", sampler_));


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc View File

@@ -43,6 +43,7 @@ void Cifar100Node::Print(std::ostream &out) const {
}

Status Cifar100Node::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_));

RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Node", sampler_));


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc View File

@@ -43,6 +43,7 @@ void Cifar10Node::Print(std::ostream &out) const {
}

Status Cifar10Node::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_));

RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Node", sampler_));


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc View File

@@ -53,6 +53,7 @@ void CLUENode::Print(std::ostream &out) const {
}

Status CLUENode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_));

RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", task_, {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}));


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc View File

@@ -46,6 +46,7 @@ std::shared_ptr<DatasetNode> CocoNode::Copy() {
void CocoNode::Print(std::ostream &out) const { out << Name(); }

Status CocoNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_));

RETURN_IF_NOT_OK(ValidateDatasetSampler("CocoNode", sampler_));


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc View File

@@ -61,6 +61,7 @@ void CSVNode::Print(std::ostream &out) const {
}

Status CSVNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_));

if (field_delim_ == '"' || field_delim_ == '\r' || field_delim_ == '\n') {


+ 4
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc View File

@@ -83,7 +83,10 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
}

// no validation is needed for generator op.
Status GeneratorNode::ValidateParams() { return Status::OK(); }
Status GeneratorNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
return Status::OK();
}

Status GeneratorNode::GetShardId(int32_t *shard_id) {
RETURN_UNEXPECTED_IF_NULL(shard_id);


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc View File

@@ -53,6 +53,7 @@ void ImageFolderNode::Print(std::ostream &out) const {
}

Status ImageFolderNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_));

RETURN_IF_NOT_OK(ValidateDatasetSampler("ImageFolderNode", sampler_));


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc View File

@@ -57,6 +57,7 @@ void ManifestNode::Print(std::ostream &out) const {
}

Status ManifestNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'};
for (char c : dataset_file_) {
auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c);


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc View File

@@ -67,6 +67,7 @@ std::shared_ptr<DatasetNode> MindDataNode::Copy() {
void MindDataNode::Print(std::ostream &out) const { out << Name() + "(file:" + dataset_file_ + ",...)"; }

Status MindDataNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (!search_for_pattern_ && dataset_files_.size() > 4096) {
std::string err_msg =
"MindDataNode: length of dataset_file must be less than or equal to 4096, dataset_file length: " +


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc View File

@@ -40,6 +40,7 @@ std::shared_ptr<DatasetNode> MnistNode::Copy() {
void MnistNode::Print(std::ostream &out) const { out << Name(); }

Status MnistNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_));

RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistNode", sampler_));


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc View File

@@ -41,6 +41,7 @@ void RandomNode::Print(std::ostream &out) const { out << Name() + "(num_row:" +

// ValidateParams for RandomNode
Status RandomNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (total_rows_ < 0) {
std::string err_msg =
"RandomNode: total_rows must be greater than or equal 0, now get " + std::to_string(total_rows_);


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc View File

@@ -55,6 +55,7 @@ void TextFileNode::Print(std::ostream &out) const {
}

Status TextFileNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_));

if (num_samples_ < 0) {


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc View File

@@ -49,6 +49,7 @@ void TFRecordNode::Print(std::ostream &out) const {

// Validator for TFRecordNode
Status TFRecordNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (dataset_files_.empty()) {
std::string err_msg = "TFRecordNode: dataset_files is not specified.";
MS_LOG(ERROR) << err_msg;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc View File

@@ -49,6 +49,7 @@ std::shared_ptr<DatasetNode> VOCNode::Copy() {
void VOCNode::Print(std::ostream &out) const { out << Name(); }

Status VOCNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
Path dir(dataset_dir_);

RETURN_IF_NOT_OK(ValidateDatasetDirParam("VOCNode", dataset_dir_));


+ 4
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc View File

@@ -52,7 +52,10 @@ Status SyncWaitNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
}

// Function to validate the parameters for SyncWaitNode
Status SyncWaitNode::ValidateParams() { return Status::OK(); }
Status SyncWaitNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc View File

@@ -46,6 +46,7 @@ Status TakeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {

// Function to validate the parameters for TakeNode
Status TakeNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (take_count_ <= 0 && take_count_ != -1) {
std::string err_msg =
"TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_);


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc View File

@@ -57,6 +57,7 @@ void TransferNode::Print(std::ostream &out) const {

// Validator for TransferNode
Status TransferNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (total_batch_ < 0) {
std::string err_msg = "TransferNode: Total batches should be >= 0, value given: ";
MS_LOG(ERROR) << err_msg << total_batch_;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc View File

@@ -41,6 +41,7 @@ std::shared_ptr<DatasetNode> ZipNode::Copy() {
void ZipNode::Print(std::ostream &out) const { out << Name(); }

Status ZipNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (children_.size() < 2) {
std::string err_msg = "ZipNode: input datasets are not specified.";
MS_LOG(ERROR) << err_msg;


+ 17
- 0
tests/ut/cpp/dataset/c_api_dataset_ops_test.cc View File

@@ -1833,3 +1833,20 @@ TEST_F(MindDataTestPipeline, TestZipSuccess2) {
// Manually terminate the pipeline
iter->Stop();
}

TEST_F(MindDataTestPipeline, TestNumWorkersValidate) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNumWorkersValidate.";

// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path);

// ds needs to be non nullptr otherwise, the subsequent logic will core dump
ASSERT_NE(ds, nullptr);

// test if set num_workers=-1
EXPECT_EQ(ds->SetNumWorkers(-1)->CreateIterator(), nullptr);

// test if set num_workers can be very large
EXPECT_EQ(ds->SetNumWorkers(INT32_MAX)->CreateIterator(), nullptr);
}

Loading…
Cancel
Save