|
|
@@ -221,7 +221,7 @@ std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::s |
|
|
// Function to overload "+" operator to concat two datasets |
|
|
// Function to overload "+" operator to concat two datasets |
|
|
std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1, |
|
|
std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1, |
|
|
const std::shared_ptr<Dataset> &datasets2) { |
|
|
const std::shared_ptr<Dataset> &datasets2) { |
|
|
std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets1, datasets2})); |
|
|
|
|
|
|
|
|
std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets2, datasets1})); |
|
|
|
|
|
|
|
|
// Call derived class validation method. |
|
|
// Call derived class validation method. |
|
|
return ds->ValidateParams() ? ds : nullptr; |
|
|
return ds->ValidateParams() ? ds : nullptr; |
|
|
@@ -1592,6 +1592,10 @@ bool ConcatDataset::ValidateParams() { |
|
|
MS_LOG(ERROR) << "Concat: concatenated datasets are not specified."; |
|
|
MS_LOG(ERROR) << "Concat: concatenated datasets are not specified."; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { |
|
|
|
|
|
MS_LOG(ERROR) << "Concat: concatenated dataset should not be null."; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -1676,6 +1680,16 @@ bool RenameDataset::ValidateParams() { |
|
|
MS_LOG(ERROR) << "input and output columns must be the same size"; |
|
|
MS_LOG(ERROR) << "input and output columns must be the same size"; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
for (uint32_t i = 0; i < input_columns_.size(); ++i) { |
|
|
|
|
|
if (input_columns_[i].empty()) { |
|
|
|
|
|
MS_LOG(ERROR) << "input_columns: column name should not be empty."; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
if (output_columns_[i].empty()) { |
|
|
|
|
|
MS_LOG(ERROR) << "output_columns: column name should not be empty."; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -1766,7 +1780,7 @@ std::vector<std::shared_ptr<DatasetOp>> TakeDataset::Build() { |
|
|
|
|
|
|
|
|
// Function to validate the parameters for TakeDataset |
|
|
// Function to validate the parameters for TakeDataset |
|
|
bool TakeDataset::ValidateParams() { |
|
|
bool TakeDataset::ValidateParams() { |
|
|
if (take_count_ < 0 && take_count_ != -1) { |
|
|
|
|
|
|
|
|
if (take_count_ <= 0 && take_count_ != -1) { |
|
|
MS_LOG(ERROR) << "Take: take_count should be either -1 or positive integer, take_count: " << take_count_; |
|
|
MS_LOG(ERROR) << "Take: take_count should be either -1 or positive integer, take_count: " << take_count_; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|