Browse Source

Fixed GetDatasetSize for TextFile

tags/v1.1.0
Mahdi 5 years ago
parent
commit
449e1526dc
2 changed files with 5 additions and 1 deletions
  1. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc
  2. +4
    -0
      tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc

+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc View File

@@ -529,7 +529,7 @@ Status TextFileOp::GetDatasetSize(int64_t *dataset_size) {
int64_t num_rows, sample_size;
sample_size = total_rows_;
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
num_rows = total_rows_;
num_rows = num_rows_per_shard_;
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();


+ 4
- 0
tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc View File

@@ -105,6 +105,10 @@ TEST_F(MindDataTestPipeline, TestTextFileGetters) {
EXPECT_EQ(ds->GetDatasetSize(), 2);
EXPECT_EQ(ds->GetColumnNames(), column_names);

ds = TextFile({tf_file1}, 0);
EXPECT_NE(ds, nullptr);

EXPECT_EQ(ds->GetDatasetSize(), 3);
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);


Loading…
Cancel
Save