Merge pull request !5187 from cathwong/ckw_c_api_fixes2
@@ -1009,9 +1009,14 @@ std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
   }
   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
+  // Sort the dataset files in a lexicographical order
+  std::vector<std::string> sorted_dataset_files = dataset_files_;
+  std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
   std::shared_ptr<ClueOp> clue_op =
     std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
-                             dataset_files_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
+                             sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
   RETURN_EMPTY_IF_ERROR(clue_op->Init());
   if (shuffle_ == ShuffleMode::kGlobal) {
     // Inject ShuffleOp
@@ -1019,10 +1024,10 @@ std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
     int64_t num_rows = 0;
     // First, get the number of rows in the dataset
-    RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(dataset_files_, &num_rows));
+    RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows));
     // Add the shuffle op after this op
-    RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
+    RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
                                        rows_per_buffer_, &shuffle_op));
     node_ops.push_back(shuffle_op);
   }
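Note on the pattern: the CLUE, CSV, and TextFile `Build()` hunks in this diff all apply the same idiom — sort a *copy* of `dataset_files_` and feed that one sorted list to the leaf op, the row counter, and the ShuffleOp sizing. A minimal sketch of the idiom, assuming only the standard library; the helper name `CanonicalFileOrder` is hypothetical and not part of MindSpore:

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Sort a copy so the caller-supplied list stays untouched; every downstream
// consumer (leaf op, row counter, shuffle sizing) then sees one stable order.
std::vector<std::string> CanonicalFileOrder(const std::vector<std::string> &dataset_files) {
  std::vector<std::string> sorted_dataset_files = dataset_files;
  std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
  return sorted_dataset_files;
}
```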
@@ -1162,6 +1167,11 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
   std::vector<std::shared_ptr<DatasetOp>> node_ops;
   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
+  // Sort the dataset files in a lexicographical order
+  std::vector<std::string> sorted_dataset_files = dataset_files_;
+  std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
   std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
   for (auto v : column_defaults_) {
     if (v->type == CsvType::INT) {
@@ -1177,8 +1187,8 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
   }
   std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
-    dataset_files_, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_, num_samples_,
-    worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
+    sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_,
+    num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
   RETURN_EMPTY_IF_ERROR(csv_op->Init());
   if (shuffle_ == ShuffleMode::kGlobal) {
     // Inject ShuffleOp
@@ -1186,10 +1196,10 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
     int64_t num_rows = 0;
     // First, get the number of rows in the dataset
-    RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(dataset_files_, column_names_.empty(), &num_rows));
+    RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows));
     // Add the shuffle op after this op
-    RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
+    RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
                                        rows_per_buffer_, &shuffle_op));
     node_ops.push_back(shuffle_op);
   }
@@ -1398,6 +1408,10 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
+  // Sort the dataset files in a lexicographical order
+  std::vector<std::string> sorted_dataset_files = dataset_files_;
+  std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
   // Do internal Schema generation.
   auto schema = std::make_unique<DataSchema>();
   RETURN_EMPTY_IF_ERROR(
@@ -1405,7 +1419,7 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
   // Create and initialize TextFileOp
   std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
-    num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), dataset_files_,
+    num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
     connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
   RETURN_EMPTY_IF_ERROR(text_file_op->Init());
@@ -1415,10 +1429,10 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
     int64_t num_rows = 0;
     // First, get the number of rows in the dataset
-    RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(dataset_files_, &num_rows));
+    RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows));
     // Add the shuffle op after this op
-    RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
+    RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
                                        rows_per_buffer_, &shuffle_op));
     node_ops.push_back(shuffle_op);
   }
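The test changes below verify this determinism end to end: each renamed "A" test passes the files in lexicographical order, and its new "B" twin passes the same files reversed, asserting the identical row sequence under the same seed. A self-contained sketch (plain C++17, illustrative file names only; not the MindSpore shuffle implementation) of why any permutation collapses to the same outcome:

```cpp
#include <algorithm>
#include <cassert>
#include <random>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> lexicographical = {"dev.json", "train.json"};
  std::vector<std::string> reversed = {"train.json", "dev.json"};

  // Canonicalize both orderings, as Build() now does.
  std::sort(lexicographical.begin(), lexicographical.end());
  std::sort(reversed.begin(), reversed.end());
  assert(lexicographical == reversed);  // sorting is permutation-invariant

  // A fixed-seed shuffle of the canonical list is then identical in both runs,
  // which is why the A and B tests can assert the same expected_result.
  std::mt19937 rng_a(135);
  std::mt19937 rng_b(135);
  std::shuffle(lexicographical.begin(), lexicographical.end(), rng_a);
  std::shuffle(reversed.begin(), reversed.end(), rng_b);
  assert(lexicographical == reversed);
  return 0;
}
```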
@@ -362,8 +362,8 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetIFLYTEK) {
   iter->Stop();
 }
-TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFiles) {
-  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEDatasetShuffleFiles.";
+TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFilesA) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEDatasetShuffleFilesA.";
   // Test CLUE Dataset with files shuffle, num_parallel_workers=1
   // Set configuration
@@ -373,7 +373,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFiles) {
   GlobalContext::config_manager()->set_seed(135);
   GlobalContext::config_manager()->set_num_parallel_workers(1);
-  // Create a CLUE Dataset, with two text files
+  // Create a CLUE Dataset, with two text files, dev.json and train.json, in lexicographical order
   // Note: train.json has 3 rows
   // Note: dev.json has 3 rows
   // Use default of all samples
@@ -383,7 +383,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFiles) {
   std::string clue_file1 = datasets_root_path_ + "/testCLUE/afqmc/train.json";
   std::string clue_file2 = datasets_root_path_ + "/testCLUE/afqmc/dev.json";
   std::string task = "AFQMC";
   std::string usage = "train";
-  std::shared_ptr<Dataset> ds = CLUE({clue_file1, clue_file2}, task, usage, 0, ShuffleMode::kFiles);
+  std::shared_ptr<Dataset> ds = CLUE({clue_file2, clue_file1}, task, usage, 0, ShuffleMode::kFiles);
   EXPECT_NE(ds, nullptr);
   // Create an iterator over the result of the above dataset.
@@ -397,12 +397,79 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFiles) {
   EXPECT_NE(row.find("sentence1"), row.end());
   std::vector<std::string> expected_result = {
     "你有花呗吗",
     "吃饭能用花呗吗",
     "蚂蚁花呗支付金额有什么限制",
     "蚂蚁借呗等额还款能否换成先息后本",
     "蚂蚁花呗说我违约了",
     "帮我看看本月花呗账单结清了没"
   };
   uint64_t i = 0;
   while (row.size() != 0) {
     auto text = row["sentence1"];
     std::string_view sv;
     text->GetItemAt(&sv, {0});
     std::string ss(sv);
     MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
     // Compare against expected result
     EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
     i++;
     iter->GetNextRow(&row);
   }
   // Expect 3 + 3 = 6 samples
   EXPECT_EQ(i, 6);
   // Manually terminate the pipeline
   iter->Stop();
   // Restore configuration
   GlobalContext::config_manager()->set_seed(original_seed);
   GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
 }
+TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFilesB) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEDatasetShuffleFilesB.";
+  // Test CLUE Dataset with files shuffle, num_parallel_workers=1
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(135);
+  GlobalContext::config_manager()->set_num_parallel_workers(1);
+  // Create a CLUE Dataset, with two text files, train.json and dev.json, in non-lexicographical order
+  // Note: train.json has 3 rows
+  // Note: dev.json has 3 rows
+  // Use default of all samples
+  // They have the same keywords
+  // Set shuffle to files shuffle
+  std::string clue_file1 = datasets_root_path_ + "/testCLUE/afqmc/train.json";
+  std::string clue_file2 = datasets_root_path_ + "/testCLUE/afqmc/dev.json";
+  std::string task = "AFQMC";
+  std::string usage = "train";
+  std::shared_ptr<Dataset> ds = CLUE({clue_file1, clue_file2}, task, usage, 0, ShuffleMode::kFiles);
+  EXPECT_NE(ds, nullptr);
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("sentence1"), row.end());
+  std::vector<std::string> expected_result = {
+    "你有花呗吗",
+    "吃饭能用花呗吗",
+    "蚂蚁花呗支付金额有什么限制",
+    "蚂蚁借呗等额还款能否换成先息后本",
+    "蚂蚁花呗说我违约了",
+    "帮我看看本月花呗账单结清了没"
+  };
+  uint64_t i = 0;
@@ -359,8 +359,8 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetException) {
   EXPECT_EQ(ds5, nullptr);
 }
-TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFiles) {
-  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetShuffleFiles.";
+TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesA) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetShuffleFilesA.";
   // Set configuration
   uint32_t original_seed = GlobalContext::config_manager()->seed();
@@ -369,7 +369,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFiles) {
   GlobalContext::config_manager()->set_seed(130);
   GlobalContext::config_manager()->set_num_parallel_workers(4);
-  // Create a CSVDataset, with single CSV file
+  // Create a CSVDataset, with 2 CSV files, 1.csv and append.csv, in lexicographical order
   std::string file1 = datasets_root_path_ + "/testCSV/1.csv";
   std::string file2 = datasets_root_path_ + "/testCSV/append.csv";
   std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
@@ -418,6 +418,66 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFiles) {
   GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
 }
+TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesB) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetShuffleFilesB.";
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(130);
+  GlobalContext::config_manager()->set_num_parallel_workers(4);
+  // Create a CSVDataset, with 2 CSV files, append.csv and 1.csv, in non-lexicographical order
+  std::string file1 = datasets_root_path_ + "/testCSV/1.csv";
+  std::string file2 = datasets_root_path_ + "/testCSV/append.csv";
+  std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
+  std::shared_ptr<Dataset> ds = CSV({file2, file1}, ',', {}, column_names, -1, ShuffleMode::kFiles);
+  EXPECT_NE(ds, nullptr);
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("col1"), row.end());
+  std::vector<std::vector<std::string>> expected_result = {
+    {"13", "14", "15", "16"},
+    {"1", "2", "3", "4"},
+    {"17", "18", "19", "20"},
+    {"5", "6", "7", "8"},
+    {"21", "22", "23", "24"},
+    {"9", "10", "11", "12"},
+  };
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    for (int j = 0; j < column_names.size(); j++) {
+      auto text = row[column_names[j]];
+      std::string_view sv;
+      text->GetItemAt(&sv, {0});
+      std::string ss(sv);
+      MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
+    }
+    iter->GetNextRow(&row);
+    i++;
+  }
+  // Expect 6 samples
+  EXPECT_EQ(i, 6);
+  // Manually terminate the pipeline
+  iter->Stop();
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
 TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleGlobal) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetShuffleGlobal.";
   // Test CSV Dataset with GLOBAL shuffle
@@ -165,8 +165,8 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetFail7) {
   EXPECT_EQ(ds, nullptr);
 }
-TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1) {
-  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse1.";
+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1A) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse1A.";
   // Test TextFile Dataset with two text files and no shuffle, num_parallel_workers=1
   // Set configuration
@@ -176,7 +176,7 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1) {
   GlobalContext::config_manager()->set_seed(654);
   GlobalContext::config_manager()->set_num_parallel_workers(1);
-  // Create a TextFile Dataset, with two text files
+  // Create a TextFile Dataset, with two text files, 1.txt then 2.txt, in lexicographical order
   // Note: 1.txt has 3 rows
   // Note: 2.txt has 2 rows
   // Use default of all samples
@@ -223,6 +223,64 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1) {
   GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
 }
+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1B) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse1B.";
+  // Test TextFile Dataset with two text files and no shuffle, num_parallel_workers=1
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(654);
+  GlobalContext::config_manager()->set_num_parallel_workers(1);
+  // Create a TextFile Dataset, with two text files, 2.txt then 1.txt, in non-lexicographical order
+  // Note: 1.txt has 3 rows
+  // Note: 2.txt has 2 rows
+  // Use default of all samples
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file2, tf_file1}, 0, ShuffleMode::kFalse);
+  EXPECT_NE(ds, nullptr);
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("text"), row.end());
+  std::vector<std::string> expected_result = {"This is a text file.", "Be happy every day.", "Good luck to everyone.",
+                                              "Another file.", "End of file."};
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    auto text = row["text"];
+    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
+    std::string_view sv;
+    text->GetItemAt(&sv, {0});
+    std::string ss(sv);
+    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+    // Compare against expected result
+    EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
+    i++;
+    iter->GetNextRow(&row);
+  }
+  // Expect 3 + 2 = 5 samples
+  EXPECT_EQ(i, 5);
+  // Manually terminate the pipeline
+  iter->Stop();
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
 TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse4Shard) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse4Shard.";
   // Test TextFile Dataset with two text files and no shuffle, num_parallel_workers=4, shard coverage
@@ -280,8 +338,8 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse4Shard) {
   GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
 }
-TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1) {
-  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles1.";
+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1A) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles1A.";
   // Test TextFile Dataset with files shuffle, num_parallel_workers=1
   // Set configuration
@@ -291,7 +349,7 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1) {
   GlobalContext::config_manager()->set_seed(135);
   GlobalContext::config_manager()->set_num_parallel_workers(1);
-  // Create a TextFile Dataset, with two text files
+  // Create a TextFile Dataset, with two text files, 1.txt then 2.txt, in lexicographical order
   // Note: 1.txt has 3 rows
   // Note: 2.txt has 2 rows
   // Use default of all samples
@@ -340,6 +398,66 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1) {
   GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
 }
+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1B) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles1B.";
+  // Test TextFile Dataset with files shuffle, num_parallel_workers=1
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(135);
+  GlobalContext::config_manager()->set_num_parallel_workers(1);
+  // Create a TextFile Dataset, with two text files, 2.txt then 1.txt, in non-lexicographical order
+  // Note: 1.txt has 3 rows
+  // Note: 2.txt has 2 rows
+  // Use default of all samples
+  // Set shuffle to files shuffle
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file2, tf_file1}, 0, ShuffleMode::kFiles);
+  EXPECT_NE(ds, nullptr);
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("text"), row.end());
+  std::vector<std::string> expected_result = {
+    "This is a text file.", "Be happy every day.", "Good luck to everyone.", "Another file.", "End of file.",
+  };
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    auto text = row["text"];
+    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
+    std::string_view sv;
+    text->GetItemAt(&sv, {0});
+    std::string ss(sv);
+    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+    // Compare against expected result
+    EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
+    i++;
+    iter->GetNextRow(&row);
+  }
+  // Expect 3 + 2 = 5 samples
+  EXPECT_EQ(i, 5);
+  // Manually terminate the pipeline
+  iter->Stop();
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
 TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles4) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles4.";
   // Test TextFile Dataset with files shuffle, num_parallel_workers=4