Browse Source

!979 dataset: repair parameter check in rename Op

Merge pull request !979 from ms_yan/rename_columns
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
6c79c00a3f
2 changed files with 9 additions and 6 deletions
  1. +1
    -6
      mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc
  2. +8
    -0
      mindspore/dataset/engine/validators.py

+ 1
- 6
mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc View File

@@ -51,12 +51,7 @@ Status RenameOp::Builder::Build(std::shared_ptr<RenameOp> *ptr) {
// constructor
RenameOp::RenameOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
int32_t op_connector_size)
: PipelineOp(op_connector_size), in_columns_(in_col_names), out_columns_(out_col_names) {
// check input & output sizes
if (in_columns_.size() != out_columns_.size()) {
MS_LOG(ERROR) << "Rename operator number of in columns != number of out columns.";
}
}
: PipelineOp(op_connector_size), in_columns_(in_col_names), out_columns_(out_col_names) {}

// destructor
RenameOp::~RenameOp() {}


+ 8
- 0
mindspore/dataset/engine/validators.py View File

@@ -884,6 +884,14 @@ def check_rename(method):
raise ValueError("{} is not provided.".format(param_name))
check_columns(param, param_name)

input_size, output_size = 1, 1
if isinstance(param_dict.get(req_param_columns[0]), list):
input_size = len(param_dict.get(req_param_columns[0]))
if isinstance(param_dict.get(req_param_columns[1]), list):
output_size = len(param_dict.get(req_param_columns[1]))
if input_size != output_size:
raise ValueError("Number of column in input_columns and output_columns is not equal.")

return method(*args, **kwargs)

return new_method


Loading…
Cancel
Save