From 449e1526dc486b3b27ff07d38249f35b84ec3753 Mon Sep 17 00:00:00 2001 From: Mahdi Date: Tue, 17 Nov 2020 15:22:58 -0500 Subject: [PATCH] Fixed GetDatasetSize for TextFile --- .../minddata/dataset/engine/datasetops/source/text_file_op.cc | 2 +- tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc index 935150c361..797383f131 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -529,7 +529,7 @@ Status TextFileOp::GetDatasetSize(int64_t *dataset_size) { int64_t num_rows, sample_size; sample_size = total_rows_; if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); - num_rows = total_rows_; + num_rows = num_rows_per_shard_; *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc b/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc index 73ebc7bdae..1dbc1f1b13 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc @@ -105,6 +105,10 @@ TEST_F(MindDataTestPipeline, TestTextFileGetters) { EXPECT_EQ(ds->GetDatasetSize(), 2); EXPECT_EQ(ds->GetColumnNames(), column_names); + ds = TextFile({tf_file1}, 0); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 3); // Restore configuration GlobalContext::config_manager()->set_seed(original_seed); GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);