diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc index 60adddb4a8..0615fdcbbc 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc @@ -15,14 +15,14 @@ */ #include "dataset/engine/datasetops/source/tf_reader_op.h" -#include -#include +#include #include #include #include #include +#include #include -#include +#include #include "proto/example.pb.h" #include "./securec.h" @@ -905,7 +905,7 @@ Status TFReaderOp::LoadIntList(const ColDescriptor ¤t_col, const dataengin return Status::OK(); } -Status TFReaderOp::CreateSchema(const std::string tf_file, const std::vector &columns_to_load) { +Status TFReaderOp::CreateSchema(const std::string tf_file, std::vector columns_to_load) { std::ifstream reader; reader.open(tf_file); @@ -926,12 +926,14 @@ Status TFReaderOp::CreateSchema(const std::string tf_file, const std::vector &feature_map = example_features.feature(); - std::vector columns = columns_to_load; - if (columns_to_load.empty()) - (void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(columns), + if (columns_to_load.empty()) { + (void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(columns_to_load), [](const auto &it) -> std::string { return it.first; }); - for (const auto &curr_col_name : columns) { + std::sort(columns_to_load.begin(), columns_to_load.end()); + } + + for (const auto &curr_col_name : columns_to_load) { auto it = feature_map.find(curr_col_name); if (it == feature_map.end()) { RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h index 3dc5ee932e..17a76f2c3d 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h @@ -335,7 +335,7 @@ class TFReaderOp : public ParallelOp { // Reads one row of data from a tf file and creates a schema based on that row // @return Status - the error code returned. - Status CreateSchema(const std::string tf_file, const std::vector &columns_to_load); + Status CreateSchema(const std::string tf_file, std::vector columns_to_load); // Meant to be called async. Will read files in the range [begin, end) and return the total rows // @param filenames - a list of tf data filenames.