Browse Source

!8055 Fixed the problem of Repeat with GetDatasetSize

Merge pull request !8055 from MahdiRahmaniHanzaki/tree-getters
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
be2a3ebf6d
4 changed files with 9 additions and 5 deletions
  1. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc
  2. +4
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc
  3. +1
    -3
      tests/ut/cpp/dataset/c_api_dataset_album_test.cc
  4. +3
    -0
      tests/ut/cpp/dataset/c_api_dataset_voc_test.cc

+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc View File

@@ -194,7 +194,7 @@ Status RepeatOp::Accept(NodePass *p, bool *modified) {

// Get Dataset size
Status RepeatOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0 || num_repeats_ == -1) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}


+ 4
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc View File

@@ -481,10 +481,13 @@ Status ManifestOp::GetNumClasses(int64_t *num_classes) {
*num_classes = num_classes_;
return Status::OK();
}
int64_t classes_count;
std::shared_ptr<ManifestOp> op;
RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op));
RETURN_IF_NOT_OK(op->ParseManifestFile());
*num_classes = num_classes_;
classes_count = static_cast<int64_t>(op->label_index_.size());
*num_classes = classes_count;
num_classes_ = classes_count;
return Status::OK();
}



+ 1
- 3
tests/ut/cpp/dataset/c_api_dataset_album_test.cc View File

@@ -79,8 +79,6 @@ TEST_F(MindDataTestPipeline, TestAlbumgetters) {
std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names);
EXPECT_NE(ds, nullptr);

int64_t dataset_size = ds->GetDatasetSize();
EXPECT_EQ(dataset_size, 7);
int64_t num_classes = ds->GetNumClasses();
EXPECT_EQ(num_classes, -1);
int64_t batch_size = ds->GetBatchSize();
@@ -114,7 +112,7 @@ TEST_F(MindDataTestPipeline, TestAlbumDecode) {
auto shape = image->shape();
MS_LOG(INFO) << "Tensor image shape size: " << shape.Size();
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
EXPECT_GT(shape.Size(), 1); // Verify decode=true took effect
EXPECT_GT(shape.Size(), 1); // Verify decode=true took effect
iter->GetNextRow(&row);
}



+ 3
- 0
tests/ut/cpp/dataset/c_api_dataset_voc_test.cc View File

@@ -99,6 +99,9 @@ TEST_F(MindDataTestPipeline, TestVOCGetDatasetSize) {
std::shared_ptr<Dataset> ds = VOC(folder_path, "Detection", "train", class_index, false, SequentialSampler(0, 6));
EXPECT_NE(ds, nullptr);

ds = ds->Batch(2);
ds = ds->Repeat(2);

EXPECT_EQ(ds->GetDatasetSize(), 6);
}



Loading…
Cancel
Save