diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index c0e0519ced..bea29de70a 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -274,6 +274,25 @@ int64_t Dataset::GetNumClasses() { return rc.IsError() ? -1 : num_classes; } +std::vector Dataset::GetColumnNames() { + std::vector col_names; + auto ds = shared_from_this(); + Status rc; + std::unique_ptr runtime_context = std::make_unique(); + rc = runtime_context->Init(); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetColumnNames: Initializing RuntimeContext failed."; + return std::vector(); + } + rc = tree_getters_->Init(ds->IRNode()); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetColumnNames: Initializing TreeGetters failed."; + return std::vector(); + } + rc = tree_getters_->GetColumnNames(&col_names); + return rc.IsError() ? std::vector() : col_names; +} + std::vector>> Dataset::GetClassIndexing() { std::vector>> output_class_indexing; auto ds = shared_from_this(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index 58863a89c4..13ac671bab 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -466,6 +466,22 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) { return Status::OK(); } +Status TreeGetters::GetColumnNames(std::vector *output) { + std::shared_ptr root = std::shared_ptr(tree_adapter_->GetRoot()); + std::unordered_map column_name_id_map = root->column_name_id_map(); + if (column_name_id_map.empty()) RETURN_STATUS_UNEXPECTED("GetColumnNames: column_name_id map was empty."); + std::vector> column_name_id_vector(column_name_id_map.begin(), + column_name_id_map.end()); + std::sort(column_name_id_vector.begin(), column_name_id_vector.end(), + [](const std::pair &a, const std::pair &b) { + return a.second < b.second; + }); + for (auto item : column_name_id_vector) { + (*output).push_back(item.first); + } + return Status::OK(); +} + Status TreeGetters::GetClassIndexing(std::vector>> *output_class_indexing) { std::shared_ptr root = std::shared_ptr(tree_adapter_->GetRoot()); CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h index 67d7dec513..7ee5c68720 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h @@ -166,6 +166,7 @@ class TreeGetters : public TreeConsumer { Status GetBatchSize(int64_t *batch_size); Status GetRepeatCount(int64_t *repeat_count); Status GetNumClasses(int64_t *num_classes); + Status GetColumnNames(std::vector *output); Status GetClassIndexing(std::vector>> *output_class_indexing); bool isInitialized(); std::string Name() override { return "TreeGetters"; } diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index ccdaf22f54..2862d2e2d4 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -119,6 +119,10 @@ class Dataset : public std::enable_shared_from_this { /// \return number of classes. If failed, return -1 int64_t GetNumClasses(); + /// \brief Gets the column names + /// \return Names of the columns. If failed, return an empty vector + std::vector GetColumnNames(); + /// \brief Gets the class indexing /// \return a map of ClassIndexing. If failed, return an empty map std::vector>> GetClassIndexing(); diff --git a/tests/ut/cpp/dataset/c_api_dataset_album_test.cc b/tests/ut/cpp/dataset/c_api_dataset_album_test.cc index a4e1c26a7d..541ed9b47a 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_album_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_album_test.cc @@ -56,8 +56,8 @@ TEST_F(MindDataTestPipeline, TestAlbumBasic) { iter->Stop(); } -TEST_F(MindDataTestPipeline, TestAlbumgetters) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAlbumgetters."; +TEST_F(MindDataTestPipeline, TestAlbumGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAlbumGetters."; std::string folder_path = datasets_root_path_ + "/testAlbum/images"; std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json"; @@ -72,6 +72,7 @@ TEST_F(MindDataTestPipeline, TestAlbumgetters) { EXPECT_EQ(batch_size, 1); int64_t repeat_count = ds->GetRepeatCount(); EXPECT_EQ(repeat_count, 1); + EXPECT_EQ(ds->GetColumnNames(), column_names); } TEST_F(MindDataTestPipeline, TestAlbumDecode) { diff --git a/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc b/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc index 1c586d7f63..4045f840f4 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc @@ -81,6 +81,7 @@ TEST_F(MindDataTestPipeline, TestCifar10Getters) { EXPECT_EQ(ds->GetDatasetSize(), 10000); std::vector types = ds->GetOutputTypes(); std::vector shapes = ds->GetOutputShapes(); + std::vector column_names = {"image", "label"}; int64_t num_classes = ds->GetNumClasses(); EXPECT_EQ(types.size(), 2); EXPECT_EQ(types[0].ToString(), "uint8"); @@ -97,6 +98,7 @@ TEST_F(MindDataTestPipeline, TestCifar10Getters) { EXPECT_EQ(ds->GetOutputShapes(), shapes); EXPECT_EQ(ds->GetNumClasses(), -1); + EXPECT_EQ(ds->GetColumnNames(), column_names); EXPECT_EQ(ds->GetDatasetSize(), 10000); EXPECT_EQ(ds->GetOutputTypes(), types); EXPECT_EQ(ds->GetOutputShapes(), shapes); @@ -141,15 +143,32 @@ TEST_F(MindDataTestPipeline, TestCifar100Dataset) { iter->Stop(); } -TEST_F(MindDataTestPipeline, TestCifar100GetDatasetSize) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100GetDatasetSize."; +TEST_F(MindDataTestPipeline, TestCifar100Getters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100Getters."; // Create a Cifar100 Dataset std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; std::shared_ptr ds = Cifar100(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); + std::vector column_names = {"image", "coarse_label", "fine_label"}; + std::vector types = ds->GetOutputTypes(); + std::vector shapes = ds->GetOutputShapes(); + int64_t num_classes = ds->GetNumClasses(); + + EXPECT_EQ(types.size(), 3); + EXPECT_EQ(types[0].ToString(), "uint8"); + EXPECT_EQ(types[1].ToString(), "uint32"); + EXPECT_EQ(types[2].ToString(), "uint32"); + EXPECT_EQ(shapes.size(), 3); + EXPECT_EQ(shapes[0].ToString(), "<32,32,3>"); + EXPECT_EQ(shapes[1].ToString(), "<>"); + EXPECT_EQ(shapes[2].ToString(), "<>"); + EXPECT_EQ(num_classes, -1); + EXPECT_EQ(ds->GetBatchSize(), 1); + EXPECT_EQ(ds->GetRepeatCount(), 1); EXPECT_EQ(ds->GetDatasetSize(), 10); + EXPECT_EQ(ds->GetColumnNames(), column_names); } TEST_F(MindDataTestPipeline, TestCifar100DatasetFail) { diff --git a/tests/ut/cpp/dataset/c_api_dataset_clue_test.cc b/tests/ut/cpp/dataset/c_api_dataset_clue_test.cc index bb9ead4d3f..7cce11d7bb 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_clue_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_clue_test.cc @@ -147,17 +147,19 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetBasic) { iter->Stop(); } -TEST_F(MindDataTestPipeline, TestCLUEGetDatasetSize) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEGetDatasetSize."; +TEST_F(MindDataTestPipeline, TestCLUEGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEGetters."; // Create a CLUEFile Dataset, with single CLUE file std::string clue_file = datasets_root_path_ + "/testCLUE/afqmc/train.json"; std::string task = "AFQMC"; std::string usage = "train"; std::shared_ptr ds = CLUE({clue_file}, task, usage, 2); + std::vector column_names = {"label", "sentence1", "sentence2"}; EXPECT_NE(ds, nullptr); EXPECT_EQ(ds->GetDatasetSize(), 2); + EXPECT_EQ(ds->GetColumnNames(), column_names); } TEST_F(MindDataTestPipeline, TestCLUEDatasetCMNLI) { diff --git a/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc b/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc index 5d92d697bb..0756eb52c0 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc @@ -61,8 +61,8 @@ TEST_F(MindDataTestPipeline, TestCocoDefault) { iter->Stop(); } -TEST_F(MindDataTestPipeline, TestCocoGetDatasetSize) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoGetDatasetSize."; +TEST_F(MindDataTestPipeline, TestCocoGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoGetters."; // Create a Coco Dataset std::string folder_path = datasets_root_path_ + "/testCOCO/train"; std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/train.json"; @@ -70,7 +70,9 @@ TEST_F(MindDataTestPipeline, TestCocoGetDatasetSize) { std::shared_ptr ds = Coco(folder_path, annotation_file); EXPECT_NE(ds, nullptr); + std::vector column_names = {"image", "bbox", "category_id", "iscrowd"}; EXPECT_EQ(ds->GetDatasetSize(), 6); + EXPECT_EQ(ds->GetColumnNames(), column_names); } TEST_F(MindDataTestPipeline, TestCocoDetection) { diff --git a/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc b/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc index af6fc24e53..b7697a437b 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc @@ -70,8 +70,8 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetBasic) { iter->Stop(); } -TEST_F(MindDataTestPipeline, TestCSVGetDatasetSize) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVGetDatasetSize."; +TEST_F(MindDataTestPipeline, TestCSVGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVGetters."; // Create a CSVDataset, with single CSV file std::string train_file = datasets_root_path_ + "/testCSV/1.csv"; @@ -80,6 +80,7 @@ TEST_F(MindDataTestPipeline, TestCSVGetDatasetSize) { EXPECT_NE(ds, nullptr); EXPECT_EQ(ds->GetDatasetSize(), 3); + EXPECT_EQ(ds->GetColumnNames(), column_names); } TEST_F(MindDataTestPipeline, TestCSVDatasetMultiFiles) { diff --git a/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc b/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc index 26820d7e12..f67e8bc73f 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc @@ -62,10 +62,12 @@ TEST_F(MindDataTestPipeline, TestManifestGetters) { // Create a Manifest Dataset std::shared_ptr ds1 = Manifest(file_path1); std::shared_ptr ds2 = Manifest(file_path2); + std::vector column_names = {"image", "label"}; EXPECT_NE(ds1, nullptr); EXPECT_EQ(ds1->GetDatasetSize(), 2); EXPECT_EQ(ds1->GetNumClasses(), 2); + EXPECT_EQ(ds1->GetColumnNames(), column_names); EXPECT_NE(ds2, nullptr); EXPECT_EQ(ds2->GetDatasetSize(), 4); diff --git a/tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc b/tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc index eca50b951f..f8d02ec873 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc @@ -57,8 +57,8 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess1) { iter->Stop(); } -TEST_F(MindDataTestPipeline, TestMindDataGetDatasetSize) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataGetDatasetSize with string file pattern."; +TEST_F(MindDataTestPipeline, TestMindDataGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataGetters with string file pattern."; // Create a MindData Dataset // Pass one mindrecord shard file to parse dataset info, and search for other mindrecord files with same dataset info, @@ -67,7 +67,10 @@ TEST_F(MindDataTestPipeline, TestMindDataGetDatasetSize) { std::shared_ptr ds = MindData(file_path); EXPECT_NE(ds, nullptr); + std::vector column_names = {"data", "file_name", "label"}; + EXPECT_EQ(ds->GetDatasetSize(), 20); + EXPECT_EQ(ds->GetColumnNames(), column_names); } TEST_F(MindDataTestPipeline, TestMindDataSuccess2) { diff --git a/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc b/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc index 59d017ec5c..fafdf57f70 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc @@ -69,8 +69,8 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasic1) { iter->Stop(); } -TEST_F(MindDataTestPipeline, TestRandomDatasetGetDatasetSize) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetGetDatasetSize."; +TEST_F(MindDataTestPipeline, TestRandomDatasetGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetGetters."; // Create a RandomDataset std::shared_ptr schema = Schema(); @@ -79,7 +79,9 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetGetDatasetSize) { std::shared_ptr ds = RandomData(50, schema); EXPECT_NE(ds, nullptr); + std::vector column_names = {"image", "label"}; EXPECT_EQ(ds->GetDatasetSize(), 50); + EXPECT_EQ(ds->GetColumnNames(), column_names); } TEST_F(MindDataTestPipeline, TestRandomDatasetBasic2) { diff --git a/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc b/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc index 1735274fba..73ebc7bdae 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc @@ -82,8 +82,8 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetBasic) { GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); } -TEST_F(MindDataTestPipeline, TestTextFileGetDatasetSize) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileGetDatasetSize."; +TEST_F(MindDataTestPipeline, TestTextFileGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileGetters."; // Test TextFile Dataset with single text file and many default inputs // Set configuration @@ -101,7 +101,9 @@ TEST_F(MindDataTestPipeline, TestTextFileGetDatasetSize) { std::shared_ptr ds = TextFile({tf_file1}, 2); EXPECT_NE(ds, nullptr); + std::vector column_names = {"text"}; EXPECT_EQ(ds->GetDatasetSize(), 2); + EXPECT_EQ(ds->GetColumnNames(), column_names); // Restore configuration GlobalContext::config_manager()->set_seed(original_seed); diff --git a/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc index 59841f6515..8b89905a93 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc @@ -84,8 +84,8 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasic) { iter->Stop(); } -TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasicGetDatasetSize) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetBasicGetDatasetSize."; +TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasicGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetBasicGetters."; // Create a TFRecord Dataset std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data"; @@ -112,6 +112,8 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasicGetDatasetSize) { EXPECT_NE(ds, nullptr); EXPECT_EQ(ds->GetDatasetSize(), 6); + std::vector column_names = {"image"}; + EXPECT_EQ(ds->GetColumnNames(), column_names); } TEST_F(MindDataTestPipeline, TestTFRecordDatasetShuffle) { diff --git a/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc b/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc index 2a44d04bcc..bef5e1006c 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc @@ -94,8 +94,8 @@ TEST_F(MindDataTestPipeline, TestVOCGetClassIndex) { EXPECT_EQ(class_index1[2].second[0], 9); } -TEST_F(MindDataTestPipeline, TestVOCGetDatasetSize) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCGetDatasetSize."; +TEST_F(MindDataTestPipeline, TestVOCGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCGetters."; // Create a VOC Dataset std::string folder_path = datasets_root_path_ + "/testVOC2012_2"; @@ -111,6 +111,8 @@ TEST_F(MindDataTestPipeline, TestVOCGetDatasetSize) { ds = ds->Repeat(2); EXPECT_EQ(ds->GetDatasetSize(), 6); + std::vector column_names = {"image", "bbox", "label", "difficult", "truncate"}; + EXPECT_EQ(ds->GetColumnNames(), column_names); } TEST_F(MindDataTestPipeline, TestVOCDetection) {