Browse Source

Added GetColumnNames to C++

tags/v1.1.0
Mahdi 5 years ago
parent
commit
5a7515e48f
15 changed files with 98 additions and 20 deletions
  1. +19
    -0
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +16
    -0
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc
  3. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h
  4. +4
    -0
      mindspore/ccsrc/minddata/dataset/include/datasets.h
  5. +3
    -2
      tests/ut/cpp/dataset/c_api_dataset_album_test.cc
  6. +21
    -2
      tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc
  7. +4
    -2
      tests/ut/cpp/dataset/c_api_dataset_clue_test.cc
  8. +4
    -2
      tests/ut/cpp/dataset/c_api_dataset_coco_test.cc
  9. +3
    -2
      tests/ut/cpp/dataset/c_api_dataset_csv_test.cc
  10. +2
    -0
      tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc
  11. +5
    -2
      tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc
  12. +4
    -2
      tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc
  13. +4
    -2
      tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc
  14. +4
    -2
      tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc
  15. +4
    -2
      tests/ut/cpp/dataset/c_api_dataset_voc_test.cc

+ 19
- 0
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -274,6 +274,25 @@ int64_t Dataset::GetNumClasses() {
return rc.IsError() ? -1 : num_classes; return rc.IsError() ? -1 : num_classes;
} }


std::vector<std::string> Dataset::GetColumnNames() {
std::vector<std::string> col_names;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetColumnNames: Initializing RuntimeContext failed.";
return std::vector<std::string>();
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetColumnNames: Initializing TreeGetters failed.";
return std::vector<std::string>();
}
rc = tree_getters_->GetColumnNames(&col_names);
return rc.IsError() ? std::vector<std::string>() : col_names;
}

std::vector<std::pair<std::string, std::vector<int32_t>>> Dataset::GetClassIndexing() { std::vector<std::pair<std::string, std::vector<int32_t>>> Dataset::GetClassIndexing() {
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing; std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
auto ds = shared_from_this(); auto ds = shared_from_this();


+ 16
- 0
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc View File

@@ -466,6 +466,22 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) {
return Status::OK(); return Status::OK();
} }


Status TreeGetters::GetColumnNames(std::vector<std::string> *output) {
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
std::unordered_map<std::string, int32_t> column_name_id_map = root->column_name_id_map();
if (column_name_id_map.empty()) RETURN_STATUS_UNEXPECTED("GetColumnNames: column_name_id map was empty.");
std::vector<std::pair<std::string, int32_t>> column_name_id_vector(column_name_id_map.begin(),
column_name_id_map.end());
std::sort(column_name_id_vector.begin(), column_name_id_vector.end(),
[](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
return a.second < b.second;
});
for (auto item : column_name_id_vector) {
(*output).push_back(item.first);
}
return Status::OK();
}

Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) { Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h View File

@@ -166,6 +166,7 @@ class TreeGetters : public TreeConsumer {
Status GetBatchSize(int64_t *batch_size); Status GetBatchSize(int64_t *batch_size);
Status GetRepeatCount(int64_t *repeat_count); Status GetRepeatCount(int64_t *repeat_count);
Status GetNumClasses(int64_t *num_classes); Status GetNumClasses(int64_t *num_classes);
Status GetColumnNames(std::vector<std::string> *output);
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing); Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
bool isInitialized(); bool isInitialized();
std::string Name() override { return "TreeGetters"; } std::string Name() override { return "TreeGetters"; }


+ 4
- 0
mindspore/ccsrc/minddata/dataset/include/datasets.h View File

@@ -119,6 +119,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return number of classes. If failed, return -1 /// \return number of classes. If failed, return -1
int64_t GetNumClasses(); int64_t GetNumClasses();


/// \brief Gets the column names
/// \return Names of the columns. If failed, return an empty vector
std::vector<std::string> GetColumnNames();

/// \brief Gets the class indexing /// \brief Gets the class indexing
/// \return a map of ClassIndexing. If failed, return an empty map /// \return a map of ClassIndexing. If failed, return an empty map
std::vector<std::pair<std::string, std::vector<int32_t>>> GetClassIndexing(); std::vector<std::pair<std::string, std::vector<int32_t>>> GetClassIndexing();


+ 3
- 2
tests/ut/cpp/dataset/c_api_dataset_album_test.cc View File

@@ -56,8 +56,8 @@ TEST_F(MindDataTestPipeline, TestAlbumBasic) {
iter->Stop(); iter->Stop();
} }


TEST_F(MindDataTestPipeline, TestAlbumgetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAlbumgetters.";
TEST_F(MindDataTestPipeline, TestAlbumGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAlbumGetters.";


std::string folder_path = datasets_root_path_ + "/testAlbum/images"; std::string folder_path = datasets_root_path_ + "/testAlbum/images";
std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json"; std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
@@ -72,6 +72,7 @@ TEST_F(MindDataTestPipeline, TestAlbumgetters) {
EXPECT_EQ(batch_size, 1); EXPECT_EQ(batch_size, 1);
int64_t repeat_count = ds->GetRepeatCount(); int64_t repeat_count = ds->GetRepeatCount();
EXPECT_EQ(repeat_count, 1); EXPECT_EQ(repeat_count, 1);
EXPECT_EQ(ds->GetColumnNames(), column_names);
} }


TEST_F(MindDataTestPipeline, TestAlbumDecode) { TEST_F(MindDataTestPipeline, TestAlbumDecode) {


+ 21
- 2
tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc View File

@@ -81,6 +81,7 @@ TEST_F(MindDataTestPipeline, TestCifar10Getters) {
EXPECT_EQ(ds->GetDatasetSize(), 10000); EXPECT_EQ(ds->GetDatasetSize(), 10000);
std::vector<DataType> types = ds->GetOutputTypes(); std::vector<DataType> types = ds->GetOutputTypes();
std::vector<TensorShape> shapes = ds->GetOutputShapes(); std::vector<TensorShape> shapes = ds->GetOutputShapes();
std::vector<std::string> column_names = {"image", "label"};
int64_t num_classes = ds->GetNumClasses(); int64_t num_classes = ds->GetNumClasses();
EXPECT_EQ(types.size(), 2); EXPECT_EQ(types.size(), 2);
EXPECT_EQ(types[0].ToString(), "uint8"); EXPECT_EQ(types[0].ToString(), "uint8");
@@ -97,6 +98,7 @@ TEST_F(MindDataTestPipeline, TestCifar10Getters) {
EXPECT_EQ(ds->GetOutputShapes(), shapes); EXPECT_EQ(ds->GetOutputShapes(), shapes);
EXPECT_EQ(ds->GetNumClasses(), -1); EXPECT_EQ(ds->GetNumClasses(), -1);


EXPECT_EQ(ds->GetColumnNames(), column_names);
EXPECT_EQ(ds->GetDatasetSize(), 10000); EXPECT_EQ(ds->GetDatasetSize(), 10000);
EXPECT_EQ(ds->GetOutputTypes(), types); EXPECT_EQ(ds->GetOutputTypes(), types);
EXPECT_EQ(ds->GetOutputShapes(), shapes); EXPECT_EQ(ds->GetOutputShapes(), shapes);
@@ -141,15 +143,32 @@ TEST_F(MindDataTestPipeline, TestCifar100Dataset) {
iter->Stop(); iter->Stop();
} }


TEST_F(MindDataTestPipeline, TestCifar100GetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100GetDatasetSize.";
TEST_F(MindDataTestPipeline, TestCifar100Getters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100Getters.";


// Create a Cifar100 Dataset // Create a Cifar100 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
std::shared_ptr<Dataset> ds = Cifar100(folder_path, "all", RandomSampler(false, 10)); std::shared_ptr<Dataset> ds = Cifar100(folder_path, "all", RandomSampler(false, 10));
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


std::vector<std::string> column_names = {"image", "coarse_label", "fine_label"};
std::vector<DataType> types = ds->GetOutputTypes();
std::vector<TensorShape> shapes = ds->GetOutputShapes();
int64_t num_classes = ds->GetNumClasses();

EXPECT_EQ(types.size(), 3);
EXPECT_EQ(types[0].ToString(), "uint8");
EXPECT_EQ(types[1].ToString(), "uint32");
EXPECT_EQ(types[2].ToString(), "uint32");
EXPECT_EQ(shapes.size(), 3);
EXPECT_EQ(shapes[0].ToString(), "<32,32,3>");
EXPECT_EQ(shapes[1].ToString(), "<>");
EXPECT_EQ(shapes[2].ToString(), "<>");
EXPECT_EQ(num_classes, -1);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetDatasetSize(), 10); EXPECT_EQ(ds->GetDatasetSize(), 10);
EXPECT_EQ(ds->GetColumnNames(), column_names);
} }


TEST_F(MindDataTestPipeline, TestCifar100DatasetFail) { TEST_F(MindDataTestPipeline, TestCifar100DatasetFail) {


+ 4
- 2
tests/ut/cpp/dataset/c_api_dataset_clue_test.cc View File

@@ -147,17 +147,19 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetBasic) {
iter->Stop(); iter->Stop();
} }


TEST_F(MindDataTestPipeline, TestCLUEGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEGetDatasetSize.";
TEST_F(MindDataTestPipeline, TestCLUEGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEGetters.";


// Create a CLUEFile Dataset, with single CLUE file // Create a CLUEFile Dataset, with single CLUE file
std::string clue_file = datasets_root_path_ + "/testCLUE/afqmc/train.json"; std::string clue_file = datasets_root_path_ + "/testCLUE/afqmc/train.json";
std::string task = "AFQMC"; std::string task = "AFQMC";
std::string usage = "train"; std::string usage = "train";
std::shared_ptr<Dataset> ds = CLUE({clue_file}, task, usage, 2); std::shared_ptr<Dataset> ds = CLUE({clue_file}, task, usage, 2);
std::vector<std::string> column_names = {"label", "sentence1", "sentence2"};
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


EXPECT_EQ(ds->GetDatasetSize(), 2); EXPECT_EQ(ds->GetDatasetSize(), 2);
EXPECT_EQ(ds->GetColumnNames(), column_names);
} }


TEST_F(MindDataTestPipeline, TestCLUEDatasetCMNLI) { TEST_F(MindDataTestPipeline, TestCLUEDatasetCMNLI) {


+ 4
- 2
tests/ut/cpp/dataset/c_api_dataset_coco_test.cc View File

@@ -61,8 +61,8 @@ TEST_F(MindDataTestPipeline, TestCocoDefault) {
iter->Stop(); iter->Stop();
} }


TEST_F(MindDataTestPipeline, TestCocoGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoGetDatasetSize.";
TEST_F(MindDataTestPipeline, TestCocoGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoGetters.";
// Create a Coco Dataset // Create a Coco Dataset
std::string folder_path = datasets_root_path_ + "/testCOCO/train"; std::string folder_path = datasets_root_path_ + "/testCOCO/train";
std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/train.json"; std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/train.json";
@@ -70,7 +70,9 @@ TEST_F(MindDataTestPipeline, TestCocoGetDatasetSize) {
std::shared_ptr<Dataset> ds = Coco(folder_path, annotation_file); std::shared_ptr<Dataset> ds = Coco(folder_path, annotation_file);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


std::vector<std::string> column_names = {"image", "bbox", "category_id", "iscrowd"};
EXPECT_EQ(ds->GetDatasetSize(), 6); EXPECT_EQ(ds->GetDatasetSize(), 6);
EXPECT_EQ(ds->GetColumnNames(), column_names);
} }


TEST_F(MindDataTestPipeline, TestCocoDetection) { TEST_F(MindDataTestPipeline, TestCocoDetection) {


+ 3
- 2
tests/ut/cpp/dataset/c_api_dataset_csv_test.cc View File

@@ -70,8 +70,8 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetBasic) {
iter->Stop(); iter->Stop();
} }


TEST_F(MindDataTestPipeline, TestCSVGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVGetDatasetSize.";
TEST_F(MindDataTestPipeline, TestCSVGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVGetters.";


// Create a CSVDataset, with single CSV file // Create a CSVDataset, with single CSV file
std::string train_file = datasets_root_path_ + "/testCSV/1.csv"; std::string train_file = datasets_root_path_ + "/testCSV/1.csv";
@@ -80,6 +80,7 @@ TEST_F(MindDataTestPipeline, TestCSVGetDatasetSize) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


EXPECT_EQ(ds->GetDatasetSize(), 3); EXPECT_EQ(ds->GetDatasetSize(), 3);
EXPECT_EQ(ds->GetColumnNames(), column_names);
} }


TEST_F(MindDataTestPipeline, TestCSVDatasetMultiFiles) { TEST_F(MindDataTestPipeline, TestCSVDatasetMultiFiles) {


+ 2
- 0
tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc View File

@@ -62,10 +62,12 @@ TEST_F(MindDataTestPipeline, TestManifestGetters) {
// Create a Manifest Dataset // Create a Manifest Dataset
std::shared_ptr<Dataset> ds1 = Manifest(file_path1); std::shared_ptr<Dataset> ds1 = Manifest(file_path1);
std::shared_ptr<Dataset> ds2 = Manifest(file_path2); std::shared_ptr<Dataset> ds2 = Manifest(file_path2);
std::vector<std::string> column_names = {"image", "label"};


EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
EXPECT_EQ(ds1->GetDatasetSize(), 2); EXPECT_EQ(ds1->GetDatasetSize(), 2);
EXPECT_EQ(ds1->GetNumClasses(), 2); EXPECT_EQ(ds1->GetNumClasses(), 2);
EXPECT_EQ(ds1->GetColumnNames(), column_names);


EXPECT_NE(ds2, nullptr); EXPECT_NE(ds2, nullptr);
EXPECT_EQ(ds2->GetDatasetSize(), 4); EXPECT_EQ(ds2->GetDatasetSize(), 4);


+ 5
- 2
tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc View File

@@ -57,8 +57,8 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess1) {
iter->Stop(); iter->Stop();
} }


TEST_F(MindDataTestPipeline, TestMindDataGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataGetDatasetSize with string file pattern.";
TEST_F(MindDataTestPipeline, TestMindDataGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataGetters with string file pattern.";


// Create a MindData Dataset // Create a MindData Dataset
// Pass one mindrecord shard file to parse dataset info, and search for other mindrecord files with same dataset info, // Pass one mindrecord shard file to parse dataset info, and search for other mindrecord files with same dataset info,
@@ -67,7 +67,10 @@ TEST_F(MindDataTestPipeline, TestMindDataGetDatasetSize) {
std::shared_ptr<Dataset> ds = MindData(file_path); std::shared_ptr<Dataset> ds = MindData(file_path);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


std::vector<std::string> column_names = {"data", "file_name", "label"};

EXPECT_EQ(ds->GetDatasetSize(), 20); EXPECT_EQ(ds->GetDatasetSize(), 20);
EXPECT_EQ(ds->GetColumnNames(), column_names);
} }


TEST_F(MindDataTestPipeline, TestMindDataSuccess2) { TEST_F(MindDataTestPipeline, TestMindDataSuccess2) {


+ 4
- 2
tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc View File

@@ -69,8 +69,8 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasic1) {
iter->Stop(); iter->Stop();
} }


TEST_F(MindDataTestPipeline, TestRandomDatasetGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetGetDatasetSize.";
TEST_F(MindDataTestPipeline, TestRandomDatasetGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetGetters.";


// Create a RandomDataset // Create a RandomDataset
std::shared_ptr<SchemaObj> schema = Schema(); std::shared_ptr<SchemaObj> schema = Schema();
@@ -79,7 +79,9 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetGetDatasetSize) {
std::shared_ptr<Dataset> ds = RandomData(50, schema); std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


std::vector<std::string> column_names = {"image", "label"};
EXPECT_EQ(ds->GetDatasetSize(), 50); EXPECT_EQ(ds->GetDatasetSize(), 50);
EXPECT_EQ(ds->GetColumnNames(), column_names);
} }


TEST_F(MindDataTestPipeline, TestRandomDatasetBasic2) { TEST_F(MindDataTestPipeline, TestRandomDatasetBasic2) {


+ 4
- 2
tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc View File

@@ -82,8 +82,8 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetBasic) {
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
} }


TEST_F(MindDataTestPipeline, TestTextFileGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileGetDatasetSize.";
TEST_F(MindDataTestPipeline, TestTextFileGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileGetters.";
// Test TextFile Dataset with single text file and many default inputs // Test TextFile Dataset with single text file and many default inputs


// Set configuration // Set configuration
@@ -101,7 +101,9 @@ TEST_F(MindDataTestPipeline, TestTextFileGetDatasetSize) {
std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 2); std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 2);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


std::vector<std::string> column_names = {"text"};
EXPECT_EQ(ds->GetDatasetSize(), 2); EXPECT_EQ(ds->GetDatasetSize(), 2);
EXPECT_EQ(ds->GetColumnNames(), column_names);


// Restore configuration // Restore configuration
GlobalContext::config_manager()->set_seed(original_seed); GlobalContext::config_manager()->set_seed(original_seed);


+ 4
- 2
tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc View File

@@ -84,8 +84,8 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasic) {
iter->Stop(); iter->Stop();
} }


TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasicGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetBasicGetDatasetSize.";
TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasicGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetBasicGetters.";


// Create a TFRecord Dataset // Create a TFRecord Dataset
std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data"; std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
@@ -112,6 +112,8 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasicGetDatasetSize) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


EXPECT_EQ(ds->GetDatasetSize(), 6); EXPECT_EQ(ds->GetDatasetSize(), 6);
std::vector<std::string> column_names = {"image"};
EXPECT_EQ(ds->GetColumnNames(), column_names);
} }


TEST_F(MindDataTestPipeline, TestTFRecordDatasetShuffle) { TEST_F(MindDataTestPipeline, TestTFRecordDatasetShuffle) {


+ 4
- 2
tests/ut/cpp/dataset/c_api_dataset_voc_test.cc View File

@@ -94,8 +94,8 @@ TEST_F(MindDataTestPipeline, TestVOCGetClassIndex) {
EXPECT_EQ(class_index1[2].second[0], 9); EXPECT_EQ(class_index1[2].second[0], 9);
} }


TEST_F(MindDataTestPipeline, TestVOCGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCGetDatasetSize.";
TEST_F(MindDataTestPipeline, TestVOCGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCGetters.";


// Create a VOC Dataset // Create a VOC Dataset
std::string folder_path = datasets_root_path_ + "/testVOC2012_2"; std::string folder_path = datasets_root_path_ + "/testVOC2012_2";
@@ -111,6 +111,8 @@ TEST_F(MindDataTestPipeline, TestVOCGetDatasetSize) {
ds = ds->Repeat(2); ds = ds->Repeat(2);


EXPECT_EQ(ds->GetDatasetSize(), 6); EXPECT_EQ(ds->GetDatasetSize(), 6);
std::vector<std::string> column_names = {"image", "bbox", "label", "difficult", "truncate"};
EXPECT_EQ(ds->GetColumnNames(), column_names);
} }


TEST_F(MindDataTestPipeline, TestVOCDetection) { TEST_F(MindDataTestPipeline, TestVOCDetection) {


Loading…
Cancel
Save