Browse Source

C++ API: Cleanup datasets.cc return Status support

tags/v1.1.0
Cathy Wong 5 years ago
parent
commit
c2c7798738
1 changed file with 7 additions and 149 deletions
  1. +7
    -149
      mindspore/ccsrc/minddata/dataset/api/datasets.cc

+ 7
- 149
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -795,25 +795,14 @@ AlbumDataset::AlbumDataset(const std::string &dataset_dir, const std::string &da
sampler_(sampler) {}

Status AlbumDataset::ValidateParams() {
Status rc;

RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumDataset", dataset_dir_));

RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AlbumDataset", {schema_path_}));
if (rc.IsError()) {
return rc;
}

RETURN_IF_NOT_OK(ValidateDatasetSampler("AlbumDataset", sampler_));
if (rc.IsError()) {
return rc;
}

if (!column_names_.empty()) {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AlbumDataset", "column_names", column_names_));
if (rc.IsError()) {
return rc;
}
}

return Status::OK();
@@ -842,22 +831,11 @@ CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler), decode_(decode), extensions_(extensions) {}

// Validate the parameters held by this CelebADataset node.
// Runs, in order: the dataset-directory check on dataset_dir_, the sampler check on sampler_,
// and a string-value check of usage_ against {"all", "train", "valid", "test"}.
// Each helper call is wrapped in RETURN_IF_NOT_OK, which is expected to return the helper's
// Status to the caller as soon as it is not OK, so no separate status variable is needed.
// Returns Status::OK() when every check passes.
Status CelebADataset::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebADataset", dataset_dir_));

  RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebADataset", sampler_));

  RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"all", "train", "valid", "test"}));

  return Status::OK();
}
@@ -885,22 +863,11 @@ Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, const std::string
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

// Validate the parameters held by this Cifar10Dataset node.
// Runs, in order: the dataset-directory check on dataset_dir_, the sampler check on sampler_,
// and a string-value check of usage_ against {"train", "test", "all"}.
// Each helper call is wrapped in RETURN_IF_NOT_OK, which is expected to return the helper's
// Status to the caller as soon as it is not OK, so no separate status variable is needed.
// Returns Status::OK() when every check passes.
Status Cifar10Dataset::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_));

  RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Dataset", sampler_));

  RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"}));

  return Status::OK();
}
@@ -929,22 +896,11 @@ Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, const std::stri
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

// Validate the parameters held by this Cifar100Dataset node.
// Runs, in order: the dataset-directory check on dataset_dir_, the sampler check on sampler_,
// and a string-value check of usage_ against {"train", "test", "all"}.
// Each helper call is wrapped in RETURN_IF_NOT_OK, which is expected to return the helper's
// Status to the caller as soon as it is not OK, so no separate status variable is needed.
// Returns Status::OK() when every check passes.
Status Cifar100Dataset::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_));

  RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Dataset", sampler_));

  RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"}));

  return Status::OK();
}
@@ -981,12 +937,7 @@ CLUEDataset::CLUEDataset(const std::vector<std::string> clue_files, std::string
shard_id_(shard_id) {}

Status CLUEDataset::ValidateParams() {
Status rc;

RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUEDataset", dataset_files_));
if (rc.IsError()) {
return rc;
}

std::vector<std::string> task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"};
std::vector<std::string> usage_list = {"train", "test", "eval"};
@@ -1010,9 +961,6 @@ Status CLUEDataset::ValidateParams() {
}

RETURN_IF_NOT_OK(ValidateDatasetShardParams("CLUEDataset", num_shards_, shard_id_));
if (rc.IsError()) {
return rc;
}

return Status::OK();
}
@@ -1173,17 +1121,9 @@ CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &anno
: dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {}

Status CocoDataset::ValidateParams() {
Status rc;

RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoDataset", dataset_dir_));
if (rc.IsError()) {
return rc;
}

RETURN_IF_NOT_OK(ValidateDatasetSampler("CocoDataset", sampler_));
if (rc.IsError()) {
return rc;
}

Path annotation_file(annotation_file_);
if (!annotation_file.Exists()) {
@@ -1279,17 +1219,7 @@ CSVDataset::CSVDataset(const std::vector<std::string> &csv_files, char field_del
shard_id_(shard_id) {}

Status CSVDataset::ValidateParams() {
Status rc;

RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUEDataset", dataset_files_));
if (rc.IsError()) {
return rc;
}

RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVDataset", dataset_files_));
if (rc.IsError()) {
return rc;
}

if (field_delim_ == '"' || field_delim_ == '\r' || field_delim_ == '\n') {
std::string err_msg = "CSVDataset: The field delimiter should not be \", \\r, \\n";
@@ -1304,9 +1234,6 @@ Status CSVDataset::ValidateParams() {
}

RETURN_IF_NOT_OK(ValidateDatasetShardParams("CSVDataset", num_shards_, shard_id_));
if (rc.IsError()) {
return rc;
}

if (find(column_defaults_.begin(), column_defaults_.end(), nullptr) != column_defaults_.end()) {
std::string err_msg = "CSVDataset: column_default should not be null.";
@@ -1316,9 +1243,6 @@ Status CSVDataset::ValidateParams() {

if (!column_names_.empty()) {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("CSVDataset", "column_names", column_names_));
if (rc.IsError()) {
return rc;
}
}

return Status::OK();
@@ -1382,17 +1306,9 @@ ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std
exts_(extensions) {}

// Validate the parameters held by this ImageFolderDataset node.
// Runs the dataset-directory check on dataset_dir_ followed by the sampler check on sampler_.
// Each helper call is wrapped in RETURN_IF_NOT_OK, which is expected to return the helper's
// Status to the caller as soon as it is not OK, so no separate status variable is needed.
// Returns Status::OK() when both checks pass.
Status ImageFolderDataset::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderDataset", dataset_dir_));

  RETURN_IF_NOT_OK(ValidateDatasetSampler("ImageFolderDataset", sampler_));

  return Status::OK();
}
@@ -1422,8 +1338,6 @@ ManifestDataset::ManifestDataset(const std::string &dataset_file, const std::str
: dataset_file_(dataset_file), usage_(usage), decode_(decode), class_index_(class_indexing), sampler_(sampler) {}

Status ManifestDataset::ValidateParams() {
Status rc;

std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'};
for (char c : dataset_file_) {
auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c);
@@ -1442,9 +1356,6 @@ Status ManifestDataset::ValidateParams() {
}

RETURN_IF_NOT_OK(ValidateDatasetSampler("ManifestDataset", sampler_));
if (rc.IsError()) {
return rc;
}

std::vector<std::string> usage_list = {"train", "eval", "inference"};
if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) {
@@ -1481,22 +1392,11 @@ MnistDataset::MnistDataset(std::string dataset_dir, std::string usage, std::shar
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

// Validate the parameters held by this MnistDataset node.
// Runs, in order: the dataset-directory check on dataset_dir_, the sampler check on sampler_,
// and a string-value check of usage_ against {"train", "test", "all"}.
// Each helper call is wrapped in RETURN_IF_NOT_OK, which is expected to return the helper's
// Status to the caller as soon as it is not OK, so no separate status variable is needed.
// Returns Status::OK() when every check passes.
Status MnistDataset::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistDataset", dataset_dir_));

  RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistDataset", sampler_));

  RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"}));

  return Status::OK();
}
@@ -1519,8 +1419,6 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {

// ValidateParams for RandomDataset
Status RandomDataset::ValidateParams() {
Status rc;

if (total_rows_ < 0) {
std::string err_msg = "RandomDataset: total_rows must be greater than or equal 0, now get " + total_rows_;
MS_LOG(ERROR) << err_msg;
@@ -1528,15 +1426,9 @@ Status RandomDataset::ValidateParams() {
}

RETURN_IF_NOT_OK(ValidateDatasetSampler("RandomDataset", sampler_));
if (rc.IsError()) {
return rc;
}

if (!columns_list_.empty()) {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RandomDataset", "columns_list", columns_list_));
if (rc.IsError()) {
return rc;
}
}

return Status::OK();
@@ -1603,12 +1495,7 @@ TextFileDataset::TextFileDataset(std::vector<std::string> dataset_files, int32_t
shard_id_(shard_id) {}

Status TextFileDataset::ValidateParams() {
Status rc;

RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUEDataset", dataset_files_));
if (rc.IsError()) {
return rc;
}

if (num_samples_ < 0) {
std::string err_msg = "TextFileDataset: Invalid number of samples: " + num_samples_;
@@ -1617,9 +1504,6 @@ Status TextFileDataset::ValidateParams() {
}

RETURN_IF_NOT_OK(ValidateDatasetShardParams("TextFileDataset", num_shards_, shard_id_));
if (rc.IsError()) {
return rc;
}

return Status::OK();
}
@@ -1728,8 +1612,6 @@ VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task,
sampler_(sampler) {}

Status VOCDataset::ValidateParams() {
Status rc;

Path dir(dataset_dir_);
if (!dir.IsDirectory()) {
std::string err_msg = "Invalid dataset path or no dataset path is specified.";
@@ -1738,9 +1620,6 @@ Status VOCDataset::ValidateParams() {
}

RETURN_IF_NOT_OK(ValidateDatasetSampler("VOCDataset", sampler_));
if (rc.IsError()) {
return rc;
}

if (task_ == "Segmentation") {
if (!class_index_.empty()) {
@@ -1953,29 +1832,29 @@ std::vector<std::shared_ptr<DatasetOp>> BuildVocabDataset::Build() {
}

// Validate the parameters held by this BuildVocabDataset node.
// Checks, in order:
//   1. vocab_ is not null.
//   2. top_k_ is strictly positive.
//   3. freq_range_ satisfies 0 <= first <= second <= kDeMaxFreq.
//   4. columns_, when non-empty, passes ValidateDatasetColumnParam.
// A failed check logs an error and returns through RETURN_STATUS_SYNTAX_ERROR or
// RETURN_IF_NOT_OK (both are expected to return a non-OK Status to the caller).
// Returns Status::OK() when every check passes.
Status BuildVocabDataset::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "BuildVocab: vocab is null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (top_k_ <= 0) {
    // Convert explicitly: adding an integer to a string literal is pointer arithmetic,
    // not concatenation, and with a non-positive top_k_ it would index before the literal.
    std::string err_msg = "BuildVocab: top_k should be positive, but got: " + std::to_string(top_k_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (freq_range_.first < 0 || freq_range_.second > kDeMaxFreq || freq_range_.first > freq_range_.second) {
    std::string err_msg = "BuildVocab: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)";
    // The log line carries the offending values; the returned message stays value-free.
    MS_LOG(ERROR) << "BuildVocab: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), "
                  << "but got [" << freq_range_.first << ", " << freq_range_.second << "]";
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (!columns_.empty()) {
    RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocab", "columns", columns_));
  }

  return Status::OK();
}
#endif
@@ -2039,8 +1918,6 @@ std::vector<std::shared_ptr<DatasetOp>> MapDataset::Build() {
}

Status MapDataset::ValidateParams() {
Status rc;

if (operations_.empty()) {
std::string err_msg = "Map: No operation is specified.";
MS_LOG(ERROR) << err_msg;
@@ -2049,22 +1926,16 @@ Status MapDataset::ValidateParams() {

if (!input_columns_.empty()) {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapDataset", "input_columns", input_columns_));
if (rc.IsError()) {
return rc;
}
}

if (!output_columns_.empty()) {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapDataset", "output_columns", output_columns_));
if (rc.IsError()) {
return rc;
}
}

if (!project_columns_.empty()) {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapDataset", "project_columns", project_columns_));
if (rc.IsError()) {
return rc;
}
}

return Status::OK();
}

@@ -2072,8 +1943,6 @@ Status MapDataset::ValidateParams() {
// Constructor for ProjectDataset: copies the caller's column names into columns_.
ProjectDataset::ProjectDataset(const std::vector<std::string> &columns)
    : columns_(columns) {
}

Status ProjectDataset::ValidateParams() {
Status rc;

if (columns_.empty()) {
std::string err_msg = "ProjectDataset: No columns are specified.";
MS_LOG(ERROR) << err_msg;
@@ -2081,9 +1950,6 @@ Status ProjectDataset::ValidateParams() {
}

RETURN_IF_NOT_OK(ValidateDatasetColumnParam("ProjectDataset", "columns", columns_));
if (rc.IsError()) {
return rc;
}

return Status::OK();
}
@@ -2102,8 +1968,6 @@ RenameDataset::RenameDataset(const std::vector<std::string> &input_columns,
: input_columns_(input_columns), output_columns_(output_columns) {}

Status RenameDataset::ValidateParams() {
Status rc;

if (input_columns_.size() != output_columns_.size()) {
std::string err_msg = "RenameDataset: input and output columns must be the same size";
MS_LOG(ERROR) << err_msg;
@@ -2111,14 +1975,8 @@ Status RenameDataset::ValidateParams() {
}

RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameDataset", "input_columns", input_columns_));
if (rc.IsError()) {
return rc;
}

RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameDataset", "output_columns", output_columns_));
if (rc.IsError()) {
return rc;
}

return Status::OK();
}


Loading…
Cancel
Save