Merge pull request !621 from liyong126/mindrecord_uttags/v0.3.0-alpha
| @@ -346,7 +346,8 @@ void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string | |||||
| MS_LOG(ERROR) << "Error in select sql statement, sql:" << common::SafeCStr(sql) << ", error: " << errmsg; | MS_LOG(ERROR) << "Error in select sql statement, sql:" << common::SafeCStr(sql) << ", error: " << errmsg; | ||||
| return; | return; | ||||
| } | } | ||||
| MS_LOG(INFO) << "Get" << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index."; | |||||
| MS_LOG(INFO) << "Get " << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index."; | |||||
| std::lock_guard<std::mutex> lck(shard_locker_); | |||||
| for (int i = 0; i < static_cast<int>(columns.size()); ++i) { | for (int i = 0; i < static_cast<int>(columns.size()); ++i) { | ||||
| categories.emplace(columns[i][0]); | categories.emplace(columns[i][0]); | ||||
| } | } | ||||
| @@ -16,9 +16,9 @@ | |||||
| #include "ut_common.h" | #include "ut_common.h" | ||||
| using mindspore::MsLogLevel::ERROR; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::LogStream; | using mindspore::LogStream; | ||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::MsLogLevel::ERROR; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| @@ -33,23 +33,6 @@ void Common::SetUp() {} | |||||
| void Common::TearDown() {} | void Common::TearDown() {} | ||||
| void Common::LoadData(const std::string &directory, std::vector<json> &json_buffer, const int max_num) { | |||||
| int count = 0; | |||||
| string input_path = directory; | |||||
| ifstream infile(input_path); | |||||
| if (!infile.is_open()) { | |||||
| MS_LOG(ERROR) << "can not open the file "; | |||||
| return; | |||||
| } | |||||
| string temp; | |||||
| while (getline(infile, temp) && count != max_num) { | |||||
| count++; | |||||
| json j = json::parse(temp); | |||||
| json_buffer.push_back(j); | |||||
| } | |||||
| infile.close(); | |||||
| } | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| #if __cplusplus | #if __cplusplus | ||||
| } | } | ||||
| @@ -70,5 +53,353 @@ const std::string FormatInfo(const std::string &message, uint32_t message_total_ | |||||
| std::string right_padding(static_cast<uint64_t>(floor(padding_length / 2.0)), '='); | std::string right_padding(static_cast<uint64_t>(floor(padding_length / 2.0)), '='); | ||||
| return left_padding + part_message + right_padding; | return left_padding + part_message + right_padding; | ||||
| } | } | ||||
| void LoadData(const std::string &directory, std::vector<json> &json_buffer, const int max_num) { | |||||
| int count = 0; | |||||
| string input_path = directory; | |||||
| ifstream infile(input_path); | |||||
| if (!infile.is_open()) { | |||||
| MS_LOG(ERROR) << "can not open the file "; | |||||
| return; | |||||
| } | |||||
| string temp; | |||||
| while (getline(infile, temp) && count != max_num) { | |||||
| count++; | |||||
| json j = json::parse(temp); | |||||
| json_buffer.push_back(j); | |||||
| } | |||||
| infile.close(); | |||||
| } | |||||
| void LoadDataFromImageNet(const std::string &directory, std::vector<json> &json_buffer, const int max_num) { | |||||
| int count = 0; | |||||
| string input_path = directory; | |||||
| ifstream infile(input_path); | |||||
| if (!infile.is_open()) { | |||||
| MS_LOG(ERROR) << "can not open the file "; | |||||
| return; | |||||
| } | |||||
| string temp; | |||||
| string filename; | |||||
| string label; | |||||
| json j; | |||||
| while (getline(infile, temp) && count != max_num) { | |||||
| count++; | |||||
| std::size_t pos = temp.find(",", 0); | |||||
| if (pos != std::string::npos) { | |||||
| j["file_name"] = temp.substr(0, pos); | |||||
| j["label"] = atoi(common::SafeCStr(temp.substr(pos + 1, temp.length()))); | |||||
| json_buffer.push_back(j); | |||||
| } | |||||
| } | |||||
| infile.close(); | |||||
| } | |||||
| int Img2DataUint8(const std::vector<std::string> &img_absolute_path, std::vector<std::vector<uint8_t>> &bin_data) { | |||||
| for (auto &file : img_absolute_path) { | |||||
| // read image file | |||||
| std::ifstream in(common::SafeCStr(file), std::ios::in | std::ios::binary | std::ios::ate); | |||||
| if (!in) { | |||||
| MS_LOG(ERROR) << common::SafeCStr(file) << " is not a directory or not exist!"; | |||||
| return -1; | |||||
| } | |||||
| // get the file size | |||||
| uint64_t size = in.tellg(); | |||||
| in.seekg(0, std::ios::beg); | |||||
| std::vector<uint8_t> file_data(size); | |||||
| in.read(reinterpret_cast<char *>(&file_data[0]), size); | |||||
| in.close(); | |||||
| bin_data.push_back(file_data); | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| int GetAbsoluteFiles(std::string directory, std::vector<std::string> &files_absolute_path) { | |||||
| DIR *dir = opendir(common::SafeCStr(directory)); | |||||
| if (dir == nullptr) { | |||||
| MS_LOG(ERROR) << common::SafeCStr(directory) << " is not a directory or not exist!"; | |||||
| return -1; | |||||
| } | |||||
| struct dirent *d_ent = nullptr; | |||||
| char dot[3] = "."; | |||||
| char dotdot[6] = ".."; | |||||
| while ((d_ent = readdir(dir)) != nullptr) { | |||||
| if ((strcmp(d_ent->d_name, dot) != 0) && (strcmp(d_ent->d_name, dotdot) != 0)) { | |||||
| if (d_ent->d_type == DT_DIR) { | |||||
| std::string new_directory = directory + std::string("/") + std::string(d_ent->d_name); | |||||
| if (directory[directory.length() - 1] == '/') { | |||||
| new_directory = directory + string(d_ent->d_name); | |||||
| } | |||||
| if (-1 == GetAbsoluteFiles(new_directory, files_absolute_path)) { | |||||
| closedir(dir); | |||||
| return -1; | |||||
| } | |||||
| } else { | |||||
| std::string absolute_path = directory + std::string("/") + std::string(d_ent->d_name); | |||||
| if (directory[directory.length() - 1] == '/') { | |||||
| absolute_path = directory + std::string(d_ent->d_name); | |||||
| } | |||||
| files_absolute_path.push_back(absolute_path); | |||||
| } | |||||
| } | |||||
| } | |||||
| closedir(dir); | |||||
| return 0; | |||||
| } | |||||
| void ShardWriterImageNet() { | |||||
| MS_LOG(INFO) << common::SafeCStr(FormatInfo("Write imageNet")); | |||||
| // load binary data | |||||
| std::vector<std::vector<uint8_t>> bin_data; | |||||
| std::vector<std::string> filenames; | |||||
| if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) { | |||||
| MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------"; | |||||
| return; | |||||
| } | |||||
| mindrecord::Img2DataUint8(filenames, bin_data); | |||||
| // init shardHeader | |||||
| ShardHeader header_data; | |||||
| MS_LOG(INFO) << "Init ShardHeader Already."; | |||||
| // create schema | |||||
| json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json; | |||||
| std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json); | |||||
| if (anno_schema == nullptr) { | |||||
| MS_LOG(ERROR) << "Build annotation schema failed"; | |||||
| return; | |||||
| } | |||||
| // add schema to shardHeader | |||||
| int anno_schema_id = header_data.AddSchema(anno_schema); | |||||
| MS_LOG(INFO) << "Init Schema Already."; | |||||
| // create index | |||||
| std::pair<uint64_t, std::string> index_field1(anno_schema_id, "file_name"); | |||||
| std::pair<uint64_t, std::string> index_field2(anno_schema_id, "label"); | |||||
| std::vector<std::pair<uint64_t, std::string>> fields; | |||||
| fields.push_back(index_field1); | |||||
| fields.push_back(index_field2); | |||||
| // add index to shardHeader | |||||
| header_data.AddIndexFields(fields); | |||||
| MS_LOG(INFO) << "Init Index Fields Already."; | |||||
| // load meta data | |||||
| std::vector<json> annotations; | |||||
| LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 10); | |||||
| // add data | |||||
| std::map<std::uint64_t, std::vector<json>> rawdatas; | |||||
| rawdatas.insert(pair<uint64_t, vector<json>>(anno_schema_id, annotations)); | |||||
| MS_LOG(INFO) << "Init Images Already."; | |||||
| // init file_writer | |||||
| std::vector<std::string> file_names; | |||||
| int file_count = 4; | |||||
| for (int i = 1; i <= file_count; i++) { | |||||
| file_names.emplace_back(std::string("./imagenet.shard0") + std::to_string(i)); | |||||
| MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]); | |||||
| } | |||||
| MS_LOG(INFO) << "Init Output Files Already."; | |||||
| { | |||||
| ShardWriter fw_init; | |||||
| fw_init.Open(file_names); | |||||
| // set shardHeader | |||||
| fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||||
| // close file_writer | |||||
| fw_init.Commit(); | |||||
| } | |||||
| std::string filename = "./imagenet.shard01"; | |||||
| { | |||||
| MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; | |||||
| mindrecord::ShardWriter fw; | |||||
| fw.OpenForAppend(filename); | |||||
| fw.WriteRawData(rawdatas, bin_data); | |||||
| fw.Commit(); | |||||
| } | |||||
| mindrecord::ShardIndexGenerator sg{filename}; | |||||
| sg.Build(); | |||||
| sg.WriteToDatabase(); | |||||
| MS_LOG(INFO) << "Done create index"; | |||||
| } | |||||
| void ShardWriterImageNetOneSample() { | |||||
| // load binary data | |||||
| std::vector<std::vector<uint8_t>> bin_data; | |||||
| std::vector<std::string> filenames; | |||||
| if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) { | |||||
| MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------"; | |||||
| return; | |||||
| } | |||||
| mindrecord::Img2DataUint8(filenames, bin_data); | |||||
| // init shardHeader | |||||
| mindrecord::ShardHeader header_data; | |||||
| MS_LOG(INFO) << "Init ShardHeader Already."; | |||||
| // create schema | |||||
| json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json; | |||||
| std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json); | |||||
| if (anno_schema == nullptr) { | |||||
| MS_LOG(ERROR) << "Build annotation schema failed"; | |||||
| return; | |||||
| } | |||||
| // add schema to shardHeader | |||||
| int anno_schema_id = header_data.AddSchema(anno_schema); | |||||
| MS_LOG(INFO) << "Init Schema Already."; | |||||
| // create index | |||||
| std::pair<uint64_t, std::string> index_field1(anno_schema_id, "file_name"); | |||||
| std::pair<uint64_t, std::string> index_field2(anno_schema_id, "label"); | |||||
| std::vector<std::pair<uint64_t, std::string>> fields; | |||||
| fields.push_back(index_field1); | |||||
| fields.push_back(index_field2); | |||||
| // add index to shardHeader | |||||
| header_data.AddIndexFields(fields); | |||||
| MS_LOG(INFO) << "Init Index Fields Already."; | |||||
| // load meta data | |||||
| std::vector<json> annotations; | |||||
| LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 1); | |||||
| // add data | |||||
| std::map<std::uint64_t, std::vector<json>> rawdatas; | |||||
| rawdatas.insert(pair<uint64_t, vector<json>>(anno_schema_id, annotations)); | |||||
| MS_LOG(INFO) << "Init Images Already."; | |||||
| // init file_writer | |||||
| std::vector<std::string> file_names; | |||||
| for (int i = 1; i <= 4; i++) { | |||||
| file_names.emplace_back(std::string("./OneSample.shard0") + std::to_string(i)); | |||||
| MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]); | |||||
| } | |||||
| MS_LOG(INFO) << "Init Output Files Already."; | |||||
| { | |||||
| mindrecord::ShardWriter fw_init; | |||||
| fw_init.Open(file_names); | |||||
| // set shardHeader | |||||
| fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||||
| // close file_writer | |||||
| fw_init.Commit(); | |||||
| } | |||||
| std::string filename = "./OneSample.shard01"; | |||||
| { | |||||
| MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; | |||||
| mindrecord::ShardWriter fw; | |||||
| fw.OpenForAppend(filename); | |||||
| bin_data = std::vector<std::vector<uint8_t>>(bin_data.begin(), bin_data.begin() + 1); | |||||
| fw.WriteRawData(rawdatas, bin_data); | |||||
| fw.Commit(); | |||||
| } | |||||
| mindrecord::ShardIndexGenerator sg{filename}; | |||||
| sg.Build(); | |||||
| sg.WriteToDatabase(); | |||||
| MS_LOG(INFO) << "Done create index"; | |||||
| } | |||||
| void ShardWriterImageNetOpenForAppend(string filename) { | |||||
| for (int i = 1; i <= 4; i++) { | |||||
| string filename = std::string("./OpenForAppendSample.shard0") + std::to_string(i); | |||||
| string db_name = std::string("./OpenForAppendSample.shard0") + std::to_string(i) + ".db"; | |||||
| remove(common::SafeCStr(filename)); | |||||
| remove(common::SafeCStr(db_name)); | |||||
| } | |||||
| // load binary data | |||||
| std::vector<std::vector<uint8_t>> bin_data; | |||||
| std::vector<std::string> filenames; | |||||
| if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) { | |||||
| MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------"; | |||||
| return; | |||||
| } | |||||
| mindrecord::Img2DataUint8(filenames, bin_data); | |||||
| // init shardHeader | |||||
| mindrecord::ShardHeader header_data; | |||||
| MS_LOG(INFO) << "Init ShardHeader Already."; | |||||
| // create schema | |||||
| json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json; | |||||
| std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json); | |||||
| if (anno_schema == nullptr) { | |||||
| MS_LOG(ERROR) << "Build annotation schema failed"; | |||||
| return; | |||||
| } | |||||
| // add schema to shardHeader | |||||
| int anno_schema_id = header_data.AddSchema(anno_schema); | |||||
| MS_LOG(INFO) << "Init Schema Already."; | |||||
| // create index | |||||
| std::pair<uint64_t, std::string> index_field1(anno_schema_id, "file_name"); | |||||
| std::pair<uint64_t, std::string> index_field2(anno_schema_id, "label"); | |||||
| std::vector<std::pair<uint64_t, std::string>> fields; | |||||
| fields.push_back(index_field1); | |||||
| fields.push_back(index_field2); | |||||
| // add index to shardHeader | |||||
| header_data.AddIndexFields(fields); | |||||
| MS_LOG(INFO) << "Init Index Fields Already."; | |||||
| // load meta data | |||||
| std::vector<json> annotations; | |||||
| LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 1); | |||||
| // add data | |||||
| std::map<std::uint64_t, std::vector<json>> rawdatas; | |||||
| rawdatas.insert(pair<uint64_t, vector<json>>(anno_schema_id, annotations)); | |||||
| MS_LOG(INFO) << "Init Images Already."; | |||||
| // init file_writer | |||||
| std::vector<std::string> file_names; | |||||
| for (int i = 1; i <= 4; i++) { | |||||
| file_names.emplace_back(std::string("./OpenForAppendSample.shard0") + std::to_string(i)); | |||||
| MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]); | |||||
| } | |||||
| MS_LOG(INFO) << "Init Output Files Already."; | |||||
| { | |||||
| mindrecord::ShardWriter fw_init; | |||||
| fw_init.Open(file_names); | |||||
| // set shardHeader | |||||
| fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||||
| // close file_writer | |||||
| fw_init.Commit(); | |||||
| } | |||||
| { | |||||
| MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; | |||||
| mindrecord::ShardWriter fw; | |||||
| auto ret = fw.OpenForAppend(filename); | |||||
| if (ret == FAILED) { | |||||
| return; | |||||
| } | |||||
| bin_data = std::vector<std::vector<uint8_t>>(bin_data.begin(), bin_data.begin() + 1); | |||||
| fw.WriteRawData(rawdatas, bin_data); | |||||
| fw.Commit(); | |||||
| } | |||||
| ShardIndexGenerator sg{filename}; | |||||
| sg.Build(); | |||||
| sg.WriteToDatabase(); | |||||
| MS_LOG(INFO) << "Done create index"; | |||||
| } | |||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,6 +17,7 @@ | |||||
| #ifndef TESTS_MINDRECORD_UT_UT_COMMON_H_ | #ifndef TESTS_MINDRECORD_UT_UT_COMMON_H_ | ||||
| #define TESTS_MINDRECORD_UT_UT_COMMON_H_ | #define TESTS_MINDRECORD_UT_UT_COMMON_H_ | ||||
| #include <dirent.h> | |||||
| #include <fstream> | #include <fstream> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -25,7 +26,9 @@ | |||||
| #include "gtest/gtest.h" | #include "gtest/gtest.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "mindrecord/include/shard_index.h" | #include "mindrecord/include/shard_index.h" | ||||
| #include "mindrecord/include/shard_header.h" | |||||
| #include "mindrecord/include/shard_index_generator.h" | |||||
| #include "mindrecord/include/shard_writer.h" | |||||
| using json = nlohmann::json; | using json = nlohmann::json; | ||||
| using std::ifstream; | using std::ifstream; | ||||
| using std::pair; | using std::pair; | ||||
| @@ -40,11 +43,10 @@ class Common : public testing::Test { | |||||
| std::string install_root; | std::string install_root; | ||||
| // every TEST_F macro will enter one | // every TEST_F macro will enter one | ||||
| void SetUp(); | |||||
| virtual void SetUp(); | |||||
| void TearDown(); | |||||
| virtual void TearDown(); | |||||
| static void LoadData(const std::string &directory, std::vector<json> &json_buffer, const int max_num); | |||||
| }; | }; | ||||
| } // namespace UT | } // namespace UT | ||||
| @@ -55,6 +57,21 @@ class Common : public testing::Test { | |||||
| /// | /// | ||||
| /// return the formatted string | /// return the formatted string | ||||
| const std::string FormatInfo(const std::string &message, uint32_t message_total_length = 128); | const std::string FormatInfo(const std::string &message, uint32_t message_total_length = 128); | ||||
| void LoadData(const std::string &directory, std::vector<json> &json_buffer, const int max_num); | |||||
| void LoadDataFromImageNet(const std::string &directory, std::vector<json> &json_buffer, const int max_num); | |||||
| int Img2DataUint8(const std::vector<std::string> &img_absolute_path, std::vector<std::vector<uint8_t>> &bin_data); | |||||
| int GetAbsoluteFiles(std::string directory, std::vector<std::string> &files_absolute_path); | |||||
| void ShardWriterImageNet(); | |||||
| void ShardWriterImageNetOneSample(); | |||||
| void ShardWriterImageNetOpenForAppend(string filename); | |||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // TESTS_MINDRECORD_UT_UT_COMMON_H_ | #endif // TESTS_MINDRECORD_UT_UT_COMMON_H_ | ||||
| @@ -29,7 +29,6 @@ | |||||
| #include "mindrecord/include/shard_statistics.h" | #include "mindrecord/include/shard_statistics.h" | ||||
| #include "securec.h" | #include "securec.h" | ||||
| #include "ut_common.h" | #include "ut_common.h" | ||||
| #include "ut_shard_writer_test.h" | |||||
| using mindspore::MsLogLevel::INFO; | using mindspore::MsLogLevel::INFO; | ||||
| using mindspore::ExceptionType::NoExceptionType; | using mindspore::ExceptionType::NoExceptionType; | ||||
| @@ -43,7 +42,7 @@ class TestShard : public UT::Common { | |||||
| }; | }; | ||||
| TEST_F(TestShard, TestShardSchemaPart) { | TEST_F(TestShard, TestShardSchemaPart) { | ||||
| TestShardWriterImageNet(); | |||||
| ShardWriterImageNet(); | |||||
| MS_LOG(INFO) << FormatInfo("Test schema"); | MS_LOG(INFO) << FormatInfo("Test schema"); | ||||
| @@ -55,6 +54,12 @@ TEST_F(TestShard, TestShardSchemaPart) { | |||||
| ASSERT_TRUE(schema != nullptr); | ASSERT_TRUE(schema != nullptr); | ||||
| MS_LOG(INFO) << "schema description: " << schema->get_desc() << ", schema: " << | MS_LOG(INFO) << "schema description: " << schema->get_desc() << ", schema: " << | ||||
| common::SafeCStr(schema->GetSchema().dump()); | common::SafeCStr(schema->GetSchema().dump()); | ||||
| for (int i = 1; i <= 4; i++) { | |||||
| string filename = std::string("./imagenet.shard0") + std::to_string(i); | |||||
| string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db"; | |||||
| remove(common::SafeCStr(filename)); | |||||
| remove(common::SafeCStr(db_name)); | |||||
| } | |||||
| } | } | ||||
| TEST_F(TestShard, TestStatisticPart) { | TEST_F(TestShard, TestStatisticPart) { | ||||
| @@ -128,6 +133,5 @@ TEST_F(TestShard, TestShardHeaderPart) { | |||||
| ASSERT_EQ(resFields, fields); | ASSERT_EQ(resFields, fields); | ||||
| } | } | ||||
| TEST_F(TestShard, TestShardWriteImage) { MS_LOG(INFO) << FormatInfo("Test writer"); } | |||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -53,38 +53,6 @@ class TestShardIndexGenerator : public UT::Common { | |||||
| TestShardIndexGenerator() {} | TestShardIndexGenerator() {} | ||||
| }; | }; | ||||
| /* | |||||
| TEST_F(TestShardIndexGenerator, GetField) { | |||||
| MS_LOG(INFO) << FormatInfo("Test ShardIndex: get field"); | |||||
| int max_num = 1; | |||||
| string input_path1 = install_root + "/test/testCBGData/data/annotation.data"; | |||||
| std::vector<json> json_buffer1; // store the image_raw_meta.data | |||||
| Common::LoadData(input_path1, json_buffer1, max_num); | |||||
| MS_LOG(INFO) << "Fetch fields: "; | |||||
| for (auto &j : json_buffer1) { | |||||
| auto v_name = ShardIndexGenerator::GetField("anno_tool", j); | |||||
| auto v_attr_name = ShardIndexGenerator::GetField("entity_instances.attributes.attr_name", j); | |||||
| auto v_entity_name = ShardIndexGenerator::GetField("entity_instances.entity_name", j); | |||||
| vector<string> names = {"\"CVAT\""}; | |||||
| for (unsigned int i = 0; i != names.size(); i++) { | |||||
| ASSERT_EQ(names[i], v_name[i]); | |||||
| } | |||||
| vector<string> attr_names = {"\"脸部评分\"", "\"特征点\"", "\"points_example\"", "\"polyline_example\"", | |||||
| "\"polyline_example\""}; | |||||
| for (unsigned int i = 0; i != attr_names.size(); i++) { | |||||
| ASSERT_EQ(attr_names[i], v_attr_name[i]); | |||||
| } | |||||
| vector<string> entity_names = {"\"276点人脸\"", "\"points_example\"", "\"polyline_example\"", | |||||
| "\"polyline_example\""}; | |||||
| for (unsigned int i = 0; i != entity_names.size(); i++) { | |||||
| ASSERT_EQ(entity_names[i], v_entity_name[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| */ | |||||
| TEST_F(TestShardIndexGenerator, TakeFieldType) { | TEST_F(TestShardIndexGenerator, TakeFieldType) { | ||||
| MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type"); | MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type"); | ||||
| @@ -40,6 +40,17 @@ namespace mindrecord { | |||||
| class TestShardOperator : public UT::Common { | class TestShardOperator : public UT::Common { | ||||
| public: | public: | ||||
| TestShardOperator() {} | TestShardOperator() {} | ||||
| void SetUp() override { ShardWriterImageNet(); } | |||||
| void TearDown() override { | |||||
| for (int i = 1; i <= 4; i++) { | |||||
| string filename = std::string("./imagenet.shard0") + std::to_string(i); | |||||
| string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db"; | |||||
| remove(common::SafeCStr(filename)); | |||||
| remove(common::SafeCStr(db_name)); | |||||
| } | |||||
| } | |||||
| }; | }; | ||||
| TEST_F(TestShardOperator, TestShardSampleBasic) { | TEST_F(TestShardOperator, TestShardSampleBasic) { | ||||
| @@ -165,7 +176,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) { | |||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | ||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; | |||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | dataset.Finish(); | ||||
| @@ -191,7 +202,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { | |||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | ||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; | |||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | dataset.Finish(); | ||||
| @@ -37,6 +37,16 @@ namespace mindrecord { | |||||
| class TestShardReader : public UT::Common { | class TestShardReader : public UT::Common { | ||||
| public: | public: | ||||
| TestShardReader() {} | TestShardReader() {} | ||||
| void SetUp() override { ShardWriterImageNet(); } | |||||
| void TearDown() override { | |||||
| for (int i = 1; i <= 4; i++) { | |||||
| string filename = std::string("./imagenet.shard0") + std::to_string(i); | |||||
| string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db"; | |||||
| remove(common::SafeCStr(filename)); | |||||
| remove(common::SafeCStr(db_name)); | |||||
| } | |||||
| } | |||||
| }; | }; | ||||
| TEST_F(TestShardReader, TestShardReaderGeneral) { | TEST_F(TestShardReader, TestShardReaderGeneral) { | ||||
| @@ -51,8 +61,8 @@ TEST_F(TestShardReader, TestShardReaderGeneral) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| for (auto& j : x) { | |||||
| for (auto& item : std::get<1>(j).items()) { | |||||
| for (auto &j : x) { | |||||
| for (auto &item : std::get<1>(j).items()) { | |||||
| MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); | MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -74,8 +84,8 @@ TEST_F(TestShardReader, TestShardReaderSample) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| for (auto& j : x) { | |||||
| for (auto& item : std::get<1>(j).items()) { | |||||
| for (auto &j : x) { | |||||
| for (auto &item : std::get<1>(j).items()) { | |||||
| MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); | MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -99,8 +109,8 @@ TEST_F(TestShardReader, TestShardReaderBlock) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetBlockNext(); | auto x = dataset.GetBlockNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| for (auto& j : x) { | |||||
| for (auto& item : std::get<1>(j).items()) { | |||||
| for (auto &j : x) { | |||||
| for (auto &item : std::get<1>(j).items()) { | |||||
| MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); | MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -119,8 +129,8 @@ TEST_F(TestShardReader, TestShardReaderEasy) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| for (auto& j : x) { | |||||
| for (auto& item : std::get<1>(j).items()) { | |||||
| for (auto &j : x) { | |||||
| for (auto &item : std::get<1>(j).items()) { | |||||
| MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); | MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -140,8 +150,8 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| for (auto& j : x) { | |||||
| for (auto& item : std::get<1>(j).items()) { | |||||
| for (auto &j : x) { | |||||
| for (auto &item : std::get<1>(j).items()) { | |||||
| MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); | MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -169,9 +179,9 @@ TEST_F(TestShardReader, TestShardVersion) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| for (auto& j : x) { | |||||
| for (auto &j : x) { | |||||
| MS_LOG(INFO) << "result size: " << std::get<0>(j).size(); | MS_LOG(INFO) << "result size: " << std::get<0>(j).size(); | ||||
| for (auto& item : std::get<1>(j).items()) { | |||||
| for (auto &item : std::get<1>(j).items()) { | |||||
| MS_LOG(INFO) << "key: " << common::SafeCStr(item.key()) << ", value: " << common::SafeCStr(item.value().dump()); | MS_LOG(INFO) << "key: " << common::SafeCStr(item.key()) << ", value: " << common::SafeCStr(item.value().dump()); | ||||
| } | } | ||||
| } | } | ||||
| @@ -201,8 +211,8 @@ TEST_F(TestShardReader, TestShardReaderConsumer) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| for (auto& j : x) { | |||||
| for (auto& item : std::get<1>(j).items()) { | |||||
| for (auto &j : x) { | |||||
| for (auto &item : std::get<1>(j).items()) { | |||||
| MS_LOG(INFO) << "key: " << common::SafeCStr(item.key()) << ", value: " << common::SafeCStr(item.value().dump()); | MS_LOG(INFO) << "key: " << common::SafeCStr(item.key()) << ", value: " << common::SafeCStr(item.value().dump()); | ||||
| } | } | ||||
| } | } | ||||
| @@ -33,15 +33,25 @@ | |||||
| #include "mindrecord/include/shard_segment.h" | #include "mindrecord/include/shard_segment.h" | ||||
| #include "ut_common.h" | #include "ut_common.h" | ||||
| using mindspore::MsLogLevel::INFO; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::LogStream; | using mindspore::LogStream; | ||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::MsLogLevel::INFO; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| class TestShardSegment : public UT::Common { | class TestShardSegment : public UT::Common { | ||||
| public: | public: | ||||
| TestShardSegment() {} | TestShardSegment() {} | ||||
| void SetUp() override { ShardWriterImageNet(); } | |||||
| void TearDown() override { | |||||
| for (int i = 1; i <= 4; i++) { | |||||
| string filename = std::string("./imagenet.shard0") + std::to_string(i); | |||||
| string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db"; | |||||
| remove(common::SafeCStr(filename)); | |||||
| remove(common::SafeCStr(db_name)); | |||||
| } | |||||
| } | |||||
| }; | }; | ||||
| TEST_F(TestShardSegment, TestShardSegment) { | TEST_F(TestShardSegment, TestShardSegment) { | ||||
| @@ -16,7 +16,6 @@ | |||||
| #include <chrono> | #include <chrono> | ||||
| #include <cstring> | #include <cstring> | ||||
| #include <dirent.h> | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| @@ -30,7 +29,6 @@ | |||||
| #include "mindrecord/include/shard_index_generator.h" | #include "mindrecord/include/shard_index_generator.h" | ||||
| #include "securec.h" | #include "securec.h" | ||||
| #include "ut_common.h" | #include "ut_common.h" | ||||
| #include "ut_shard_writer_test.h" | |||||
| using mindspore::LogStream; | using mindspore::LogStream; | ||||
| using mindspore::ExceptionType::NoExceptionType; | using mindspore::ExceptionType::NoExceptionType; | ||||
| @@ -44,249 +42,10 @@ class TestShardWriter : public UT::Common { | |||||
| TestShardWriter() {} | TestShardWriter() {} | ||||
| }; | }; | ||||
| void LoadDataFromImageNet(const std::string &directory, std::vector<json> &json_buffer, const int max_num) { | |||||
| int count = 0; | |||||
| string input_path = directory; | |||||
| ifstream infile(input_path); | |||||
| if (!infile.is_open()) { | |||||
| MS_LOG(ERROR) << "can not open the file "; | |||||
| return; | |||||
| } | |||||
| string temp; | |||||
| string filename; | |||||
| string label; | |||||
| json j; | |||||
| while (getline(infile, temp) && count != max_num) { | |||||
| count++; | |||||
| std::size_t pos = temp.find(",", 0); | |||||
| if (pos != std::string::npos) { | |||||
| j["file_name"] = temp.substr(0, pos); | |||||
| j["label"] = atoi(common::SafeCStr(temp.substr(pos + 1, temp.length()))); | |||||
| json_buffer.push_back(j); | |||||
| } | |||||
| } | |||||
| infile.close(); | |||||
| } | |||||
| int Img2DataUint8(const std::vector<std::string> &img_absolute_path, std::vector<std::vector<uint8_t>> &bin_data) { | |||||
| for (auto &file : img_absolute_path) { | |||||
| // read image file | |||||
| std::ifstream in(common::SafeCStr(file), std::ios::in | std::ios::binary | std::ios::ate); | |||||
| if (!in) { | |||||
| MS_LOG(ERROR) << common::SafeCStr(file) << " is not a directory or not exist!"; | |||||
| return -1; | |||||
| } | |||||
| // get the file size | |||||
| uint64_t size = in.tellg(); | |||||
| in.seekg(0, std::ios::beg); | |||||
| std::vector<uint8_t> file_data(size); | |||||
| in.read(reinterpret_cast<char *>(&file_data[0]), size); | |||||
| in.close(); | |||||
| bin_data.push_back(file_data); | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| int GetAbsoluteFiles(std::string directory, std::vector<std::string> &files_absolute_path) { | |||||
| DIR *dir = opendir(common::SafeCStr(directory)); | |||||
| if (dir == nullptr) { | |||||
| MS_LOG(ERROR) << common::SafeCStr(directory) << " is not a directory or not exist!"; | |||||
| return -1; | |||||
| } | |||||
| struct dirent *d_ent = nullptr; | |||||
| char dot[3] = "."; | |||||
| char dotdot[6] = ".."; | |||||
| while ((d_ent = readdir(dir)) != nullptr) { | |||||
| if ((strcmp(d_ent->d_name, dot) != 0) && (strcmp(d_ent->d_name, dotdot) != 0)) { | |||||
| if (d_ent->d_type == DT_DIR) { | |||||
| std::string new_directory = directory + std::string("/") + std::string(d_ent->d_name); | |||||
| if (directory[directory.length() - 1] == '/') { | |||||
| new_directory = directory + string(d_ent->d_name); | |||||
| } | |||||
| if (-1 == GetAbsoluteFiles(new_directory, files_absolute_path)) { | |||||
| closedir(dir); | |||||
| return -1; | |||||
| } | |||||
| } else { | |||||
| std::string absolute_path = directory + std::string("/") + std::string(d_ent->d_name); | |||||
| if (directory[directory.length() - 1] == '/') { | |||||
| absolute_path = directory + std::string(d_ent->d_name); | |||||
| } | |||||
| files_absolute_path.push_back(absolute_path); | |||||
| } | |||||
| } | |||||
| } | |||||
| closedir(dir); | |||||
| return 0; | |||||
| } | |||||
| void TestShardWriterImageNet() { | |||||
| MS_LOG(INFO) << common::SafeCStr(FormatInfo("Write imageNet")); | |||||
| // load binary data | |||||
| std::vector<std::vector<uint8_t>> bin_data; | |||||
| std::vector<std::string> filenames; | |||||
| if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) { | |||||
| MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------"; | |||||
| return; | |||||
| } | |||||
| mindrecord::Img2DataUint8(filenames, bin_data); | |||||
| // init shardHeader | |||||
| mindrecord::ShardHeader header_data; | |||||
| MS_LOG(INFO) << "Init ShardHeader Already."; | |||||
| // create schema | |||||
| json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json; | |||||
| std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json); | |||||
| if (anno_schema == nullptr) { | |||||
| MS_LOG(ERROR) << "Build annotation schema failed"; | |||||
| return; | |||||
| } | |||||
| // add schema to shardHeader | |||||
| int anno_schema_id = header_data.AddSchema(anno_schema); | |||||
| MS_LOG(INFO) << "Init Schema Already."; | |||||
| // create index | |||||
| std::pair<uint64_t, std::string> index_field1(anno_schema_id, "file_name"); | |||||
| std::pair<uint64_t, std::string> index_field2(anno_schema_id, "label"); | |||||
| std::vector<std::pair<uint64_t, std::string>> fields; | |||||
| fields.push_back(index_field1); | |||||
| fields.push_back(index_field2); | |||||
| // add index to shardHeader | |||||
| header_data.AddIndexFields(fields); | |||||
| MS_LOG(INFO) << "Init Index Fields Already."; | |||||
| // load meta data | |||||
| std::vector<json> annotations; | |||||
| LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 10); | |||||
| // add data | |||||
| std::map<std::uint64_t, std::vector<json>> rawdatas; | |||||
| rawdatas.insert(pair<uint64_t, vector<json>>(anno_schema_id, annotations)); | |||||
| MS_LOG(INFO) << "Init Images Already."; | |||||
| // init file_writer | |||||
| std::vector<std::string> file_names; | |||||
| int file_count = 4; | |||||
| for (int i = 1; i <= file_count; i++) { | |||||
| file_names.emplace_back(std::string("./imagenet.shard0") + std::to_string(i)); | |||||
| MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]); | |||||
| } | |||||
| MS_LOG(INFO) << "Init Output Files Already."; | |||||
| { | |||||
| mindrecord::ShardWriter fw_init; | |||||
| fw_init.Open(file_names); | |||||
| // set shardHeader | |||||
| fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||||
| // close file_writer | |||||
| fw_init.Commit(); | |||||
| } | |||||
| std::string filename = "./imagenet.shard01"; | |||||
| { | |||||
| MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; | |||||
| mindrecord::ShardWriter fw; | |||||
| fw.OpenForAppend(filename); | |||||
| fw.WriteRawData(rawdatas, bin_data); | |||||
| fw.Commit(); | |||||
| } | |||||
| mindrecord::ShardIndexGenerator sg{filename}; | |||||
| sg.Build(); | |||||
| sg.WriteToDatabase(); | |||||
| MS_LOG(INFO) << "Done create index"; | |||||
| } | |||||
| void TestShardWriterImageNetOneSample() { | |||||
| // load binary data | |||||
| std::vector<std::vector<uint8_t>> bin_data; | |||||
| std::vector<std::string> filenames; | |||||
| if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) { | |||||
| MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------"; | |||||
| return; | |||||
| } | |||||
| mindrecord::Img2DataUint8(filenames, bin_data); | |||||
| // init shardHeader | |||||
| mindrecord::ShardHeader header_data; | |||||
| MS_LOG(INFO) << "Init ShardHeader Already."; | |||||
| // create schema | |||||
| json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json; | |||||
| std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json); | |||||
| if (anno_schema == nullptr) { | |||||
| MS_LOG(ERROR) << "Build annotation schema failed"; | |||||
| return; | |||||
| } | |||||
| // add schema to shardHeader | |||||
| int anno_schema_id = header_data.AddSchema(anno_schema); | |||||
| MS_LOG(INFO) << "Init Schema Already."; | |||||
| // create index | |||||
| std::pair<uint64_t, std::string> index_field1(anno_schema_id, "file_name"); | |||||
| std::pair<uint64_t, std::string> index_field2(anno_schema_id, "label"); | |||||
| std::vector<std::pair<uint64_t, std::string>> fields; | |||||
| fields.push_back(index_field1); | |||||
| fields.push_back(index_field2); | |||||
| // add index to shardHeader | |||||
| header_data.AddIndexFields(fields); | |||||
| MS_LOG(INFO) << "Init Index Fields Already."; | |||||
| // load meta data | |||||
| std::vector<json> annotations; | |||||
| LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 1); | |||||
| // add data | |||||
| std::map<std::uint64_t, std::vector<json>> rawdatas; | |||||
| rawdatas.insert(pair<uint64_t, vector<json>>(anno_schema_id, annotations)); | |||||
| MS_LOG(INFO) << "Init Images Already."; | |||||
| // init file_writer | |||||
| std::vector<std::string> file_names; | |||||
| for (int i = 1; i <= 4; i++) { | |||||
| file_names.emplace_back(std::string("./OneSample.shard0") + std::to_string(i)); | |||||
| MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]); | |||||
| } | |||||
| MS_LOG(INFO) << "Init Output Files Already."; | |||||
| { | |||||
| mindrecord::ShardWriter fw_init; | |||||
| fw_init.Open(file_names); | |||||
| // set shardHeader | |||||
| fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||||
| // close file_writer | |||||
| fw_init.Commit(); | |||||
| } | |||||
| std::string filename = "./OneSample.shard01"; | |||||
| { | |||||
| MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; | |||||
| mindrecord::ShardWriter fw; | |||||
| fw.OpenForAppend(filename); | |||||
| bin_data = std::vector<std::vector<uint8_t>>(bin_data.begin(), bin_data.begin() + 1); | |||||
| fw.WriteRawData(rawdatas, bin_data); | |||||
| fw.Commit(); | |||||
| } | |||||
| mindrecord::ShardIndexGenerator sg{filename}; | |||||
| sg.Build(); | |||||
| sg.WriteToDatabase(); | |||||
| MS_LOG(INFO) << "Done create index"; | |||||
| } | |||||
| TEST_F(TestShardWriter, TestShardWriterBench) { | TEST_F(TestShardWriter, TestShardWriterBench) { | ||||
| MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test write imageNet")); | MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test write imageNet")); | ||||
| TestShardWriterImageNet(); | |||||
| ShardWriterImageNet(); | |||||
| for (int i = 1; i <= 4; i++) { | for (int i = 1; i <= 4; i++) { | ||||
| string filename = std::string("./imagenet.shard0") + std::to_string(i); | string filename = std::string("./imagenet.shard0") + std::to_string(i); | ||||
| string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db"; | string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db"; | ||||
| @@ -297,7 +56,7 @@ TEST_F(TestShardWriter, TestShardWriterBench) { | |||||
| TEST_F(TestShardWriter, TestShardWriterOneSample) { | TEST_F(TestShardWriter, TestShardWriterOneSample) { | ||||
| MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test write imageNet int32 of sample less than num of shards")); | MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test write imageNet int32 of sample less than num of shards")); | ||||
| TestShardWriterImageNetOneSample(); | |||||
| ShardWriterImageNetOneSample(); | |||||
| std::string filename = "./OneSample.shard01"; | std::string filename = "./OneSample.shard01"; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| @@ -342,7 +101,7 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) { | |||||
| std::vector<std::string> image_filenames; // save all files' path within path_dir | std::vector<std::string> image_filenames; // save all files' path within path_dir | ||||
| // read image_raw_meta.data | // read image_raw_meta.data | ||||
| Common::LoadData(input_path1, json_buffer1, kMaxNum); | |||||
| LoadData(input_path1, json_buffer1, kMaxNum); | |||||
| MS_LOG(INFO) << "Load Meta Data Already."; | MS_LOG(INFO) << "Load Meta Data Already."; | ||||
| // get files' pathes stored in vector<string> image_filenames | // get files' pathes stored in vector<string> image_filenames | ||||
| @@ -375,7 +134,7 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) { | |||||
| MS_LOG(INFO) << "Init Schema Already."; | MS_LOG(INFO) << "Init Schema Already."; | ||||
| // create/init statistics | // create/init statistics | ||||
| Common::LoadData(input_path3, json_buffer4, 2); | |||||
| LoadData(input_path3, json_buffer4, 2); | |||||
| json static1_json = json_buffer4[0]; | json static1_json = json_buffer4[0]; | ||||
| json static2_json = json_buffer4[1]; | json static2_json = json_buffer4[1]; | ||||
| MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump()); | MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump()); | ||||
| @@ -474,7 +233,7 @@ TEST_F(TestShardWriter, TestShardWriterTrial) { | |||||
| std::vector<std::string> image_filenames; // save all files' path within path_dir | std::vector<std::string> image_filenames; // save all files' path within path_dir | ||||
| // read image_raw_meta.data | // read image_raw_meta.data | ||||
| Common::LoadData(input_path1, json_buffer1, kMaxNum); | |||||
| LoadData(input_path1, json_buffer1, kMaxNum); | |||||
| MS_LOG(INFO) << "Load Meta Data Already."; | MS_LOG(INFO) << "Load Meta Data Already."; | ||||
| // get files' pathes stored in vector<string> image_filenames | // get files' pathes stored in vector<string> image_filenames | ||||
| @@ -508,7 +267,7 @@ TEST_F(TestShardWriter, TestShardWriterTrial) { | |||||
| MS_LOG(INFO) << "Init Schema Already."; | MS_LOG(INFO) << "Init Schema Already."; | ||||
| // create/init statistics | // create/init statistics | ||||
| Common::LoadData(input_path3, json_buffer4, 2); | |||||
| LoadData(input_path3, json_buffer4, 2); | |||||
| json static1_json = json_buffer4[0]; | json static1_json = json_buffer4[0]; | ||||
| json static2_json = json_buffer4[1]; | json static2_json = json_buffer4[1]; | ||||
| MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump()); | MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump()); | ||||
| @@ -613,7 +372,7 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) { | |||||
| std::vector<std::string> image_filenames; // save all files' path within path_dir | std::vector<std::string> image_filenames; // save all files' path within path_dir | ||||
| // read image_raw_meta.data | // read image_raw_meta.data | ||||
| Common::LoadData(input_path1, json_buffer1, kMaxNum); | |||||
| LoadData(input_path1, json_buffer1, kMaxNum); | |||||
| MS_LOG(INFO) << "Load Meta Data Already."; | MS_LOG(INFO) << "Load Meta Data Already."; | ||||
| // get files' pathes stored in vector<string> image_filenames | // get files' pathes stored in vector<string> image_filenames | ||||
| @@ -644,7 +403,7 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) { | |||||
| MS_LOG(INFO) << "Init Schema Already."; | MS_LOG(INFO) << "Init Schema Already."; | ||||
| // create/init statistics | // create/init statistics | ||||
| Common::LoadData(input_path3, json_buffer4, 2); | |||||
| LoadData(input_path3, json_buffer4, 2); | |||||
| json static1_json = json_buffer4[0]; | json static1_json = json_buffer4[0]; | ||||
| json static2_json = json_buffer4[1]; | json static2_json = json_buffer4[1]; | ||||
| MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump()); | MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump()); | ||||
| @@ -1357,107 +1116,24 @@ TEST_F(TestShardWriter, TestWriteOpenFileName) { | |||||
| } | } | ||||
| } | } | ||||
| void TestShardWriterImageNetOpenForAppend(string filename) { | |||||
| for (int i = 1; i <= 4; i++) { | |||||
| string filename = std::string("./OpenForAppendSample.shard0") + std::to_string(i); | |||||
| string db_name = std::string("./OpenForAppendSample.shard0") + std::to_string(i) + ".db"; | |||||
| remove(common::SafeCStr(filename)); | |||||
| remove(common::SafeCStr(db_name)); | |||||
| } | |||||
| // load binary data | |||||
| std::vector<std::vector<uint8_t>> bin_data; | |||||
| std::vector<std::string> filenames; | |||||
| if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) { | |||||
| MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------"; | |||||
| return; | |||||
| } | |||||
| mindrecord::Img2DataUint8(filenames, bin_data); | |||||
| // init shardHeader | |||||
| mindrecord::ShardHeader header_data; | |||||
| MS_LOG(INFO) << "Init ShardHeader Already."; | |||||
| // create schema | |||||
| json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json; | |||||
| std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json); | |||||
| if (anno_schema == nullptr) { | |||||
| MS_LOG(ERROR) << "Build annotation schema failed"; | |||||
| return; | |||||
| } | |||||
| // add schema to shardHeader | |||||
| int anno_schema_id = header_data.AddSchema(anno_schema); | |||||
| MS_LOG(INFO) << "Init Schema Already."; | |||||
| // create index | |||||
| std::pair<uint64_t, std::string> index_field1(anno_schema_id, "file_name"); | |||||
| std::pair<uint64_t, std::string> index_field2(anno_schema_id, "label"); | |||||
| std::vector<std::pair<uint64_t, std::string>> fields; | |||||
| fields.push_back(index_field1); | |||||
| fields.push_back(index_field2); | |||||
| // add index to shardHeader | |||||
| header_data.AddIndexFields(fields); | |||||
| MS_LOG(INFO) << "Init Index Fields Already."; | |||||
| // load meta data | |||||
| std::vector<json> annotations; | |||||
| LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 1); | |||||
| // add data | |||||
| std::map<std::uint64_t, std::vector<json>> rawdatas; | |||||
| rawdatas.insert(pair<uint64_t, vector<json>>(anno_schema_id, annotations)); | |||||
| MS_LOG(INFO) << "Init Images Already."; | |||||
| // init file_writer | |||||
| std::vector<std::string> file_names; | |||||
| for (int i = 1; i <= 4; i++) { | |||||
| file_names.emplace_back(std::string("./OpenForAppendSample.shard0") + std::to_string(i)); | |||||
| MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]); | |||||
| } | |||||
| MS_LOG(INFO) << "Init Output Files Already."; | |||||
| { | |||||
| mindrecord::ShardWriter fw_init; | |||||
| fw_init.Open(file_names); | |||||
| // set shardHeader | |||||
| fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||||
| // close file_writer | |||||
| fw_init.Commit(); | |||||
| } | |||||
| { | |||||
| MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; | |||||
| mindrecord::ShardWriter fw; | |||||
| auto ret = fw.OpenForAppend(filename); | |||||
| if (ret == FAILED) { | |||||
| return; | |||||
| } | |||||
| bin_data = std::vector<std::vector<uint8_t>>(bin_data.begin(), bin_data.begin() + 1); | |||||
| fw.WriteRawData(rawdatas, bin_data); | |||||
| fw.Commit(); | |||||
| } | |||||
| mindrecord::ShardIndexGenerator sg{filename}; | |||||
| sg.Build(); | |||||
| sg.WriteToDatabase(); | |||||
| MS_LOG(INFO) << "Done create index"; | |||||
| } | |||||
| TEST_F(TestShardWriter, TestOpenForAppend) { | TEST_F(TestShardWriter, TestOpenForAppend) { | ||||
| MS_LOG(INFO) << "start ---- TestOpenForAppend\n"; | MS_LOG(INFO) << "start ---- TestOpenForAppend\n"; | ||||
| string filename = "./"; | string filename = "./"; | ||||
| TestShardWriterImageNetOpenForAppend(filename); | |||||
| ShardWriterImageNetOpenForAppend(filename); | |||||
| string filename1 = "./▒AppendSample.shard01"; | string filename1 = "./▒AppendSample.shard01"; | ||||
| TestShardWriterImageNetOpenForAppend(filename1); | |||||
| ShardWriterImageNetOpenForAppend(filename1); | |||||
| string filename2 = "./ä\xA9ü"; | string filename2 = "./ä\xA9ü"; | ||||
| TestShardWriterImageNetOpenForAppend(filename2); | |||||
| ShardWriterImageNetOpenForAppend(filename2); | |||||
| MS_LOG(INFO) << "end ---- TestOpenForAppend\n"; | MS_LOG(INFO) << "end ---- TestOpenForAppend\n"; | ||||
| for (int i = 1; i <= 4; i++) { | |||||
| string filename = std::string("./OpenForAppendSample.shard0") + std::to_string(i); | |||||
| string db_name = std::string("./OpenForAppendSample.shard0") + std::to_string(i) + ".db"; | |||||
| remove(common::SafeCStr(filename)); | |||||
| remove(common::SafeCStr(db_name)); | |||||
| } | |||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| @@ -1,26 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef TESTS_MINDRECORD_UT_SHARDWRITER_H | |||||
| #define TESTS_MINDRECORD_UT_SHARDWRITER_H | |||||
| namespace mindspore { | |||||
| namespace mindrecord { | |||||
| void TestShardWriterImageNet(); | |||||
| } // namespace mindrecord | |||||
| } // namespace mindspore | |||||
| #endif // TESTS_MINDRECORD_UT_SHARDWRITER_H | |||||