|
|
|
@@ -13,6 +13,7 @@ |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
#include <thread> |
|
|
|
|
|
|
|
#include "mindrecord/include/shard_index_generator.h" |
|
|
|
#include "common/utils.h" |
|
|
|
@@ -26,7 +27,13 @@ using mindspore::MsLogLevel::INFO; |
|
|
|
namespace mindspore { |
|
|
|
namespace mindrecord { |
|
|
|
ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append) |
|
|
|
: file_path_(file_path), append_(append), page_size_(0), header_size_(0), schema_count_(0) {} |
|
|
|
: file_path_(file_path), |
|
|
|
append_(append), |
|
|
|
page_size_(0), |
|
|
|
header_size_(0), |
|
|
|
schema_count_(0), |
|
|
|
task_(0), |
|
|
|
write_success_(true) {} |
|
|
|
|
|
|
|
MSRStatus ShardIndexGenerator::Build() { |
|
|
|
ShardHeader header = ShardHeader(); |
|
|
|
@@ -284,7 +291,7 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateRawSQL( |
|
|
|
return {SUCCESS, sql}; |
|
|
|
} |
|
|
|
|
|
|
|
MSRStatus ShardIndexGenerator::BindParamaterExecuteSQL( |
|
|
|
MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( |
|
|
|
sqlite3 *db, const std::string &sql, |
|
|
|
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) { |
|
|
|
sqlite3_stmt *stmt = nullptr; |
|
|
|
@@ -471,9 +478,9 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s |
|
|
|
return {SUCCESS, std::move(fields)}; |
|
|
|
} |
|
|
|
|
|
|
|
MSRStatus ShardIndexGenerator::ExcuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db, |
|
|
|
const std::vector<int> &raw_page_ids, |
|
|
|
const std::map<int, int> &blob_id_to_page_id) { |
|
|
|
MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db, |
|
|
|
const std::vector<int> &raw_page_ids, |
|
|
|
const std::map<int, int> &blob_id_to_page_id) { |
|
|
|
// Add index data to database |
|
|
|
std::string shard_address = shard_header_.get_shard_address_by_id(shard_no); |
|
|
|
if (shard_address.empty()) { |
|
|
|
@@ -493,7 +500,7 @@ MSRStatus ShardIndexGenerator::ExcuteTransaction(const int &shard_no, const std: |
|
|
|
if (data.first != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (BindParamaterExecuteSQL(db.second, sql.second, data.second) == FAILED) { |
|
|
|
if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db."; |
|
|
|
@@ -514,37 +521,62 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() { |
|
|
|
page_size_ = shard_header_.get_page_size(); |
|
|
|
header_size_ = shard_header_.get_header_size(); |
|
|
|
schema_count_ = shard_header_.get_schema_count(); |
|
|
|
if (shard_header_.get_shard_count() <= kMaxShardCount) { |
|
|
|
// Create one database per shard |
|
|
|
for (int shard_no = 0; shard_no < shard_header_.get_shard_count(); ++shard_no) { |
|
|
|
// Create database |
|
|
|
auto db = CreateDatabase(shard_no); |
|
|
|
if (db.first != SUCCESS || db.second == nullptr) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully."; |
|
|
|
|
|
|
|
// Pre-processing page information |
|
|
|
auto total_pages = shard_header_.GetLastPageId(shard_no) + 1; |
|
|
|
|
|
|
|
std::map<int, int> blob_id_to_page_id; |
|
|
|
std::vector<int> raw_page_ids; |
|
|
|
for (uint64_t i = 0; i < total_pages; ++i) { |
|
|
|
std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first; |
|
|
|
if (cur_page->get_page_type() == "RAW_DATA") { |
|
|
|
raw_page_ids.push_back(i); |
|
|
|
} else if (cur_page->get_page_type() == "BLOB_DATA") { |
|
|
|
blob_id_to_page_id[cur_page->get_page_type_id()] = i; |
|
|
|
} |
|
|
|
} |
|
|
|
if (shard_header_.get_shard_count() > kMaxShardCount) { |
|
|
|
MS_LOG(ERROR) << "num shards: " << shard_header_.get_shard_count() << " exceeds max count:" << kMaxSchemaCount; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
task_ = 0; // set two atomic vars to initial value |
|
|
|
write_success_ = true; |
|
|
|
|
|
|
|
if (ExcuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
// spawn half the physical threads or total number of shards whichever is smaller |
|
|
|
const unsigned int num_workers = |
|
|
|
std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.get_shard_count())); |
|
|
|
|
|
|
|
std::vector<std::thread> threads; |
|
|
|
threads.reserve(num_workers); |
|
|
|
|
|
|
|
for (size_t t = 0; t < threads.capacity(); t++) { |
|
|
|
threads.emplace_back(std::thread(&ShardIndexGenerator::DatabaseWriter, this)); |
|
|
|
} |
|
|
|
|
|
|
|
for (size_t t = 0; t < threads.capacity(); t++) { |
|
|
|
threads[t].join(); |
|
|
|
} |
|
|
|
return write_success_ ? SUCCESS : FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
void ShardIndexGenerator::DatabaseWriter() { |
|
|
|
int shard_no = task_++; |
|
|
|
while (shard_no < shard_header_.get_shard_count()) { |
|
|
|
auto db = CreateDatabase(shard_no); |
|
|
|
if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) { |
|
|
|
write_success_ = false; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully."; |
|
|
|
|
|
|
|
// Pre-processing page information |
|
|
|
auto total_pages = shard_header_.GetLastPageId(shard_no) + 1; |
|
|
|
|
|
|
|
std::map<int, int> blob_id_to_page_id; |
|
|
|
std::vector<int> raw_page_ids; |
|
|
|
for (uint64_t i = 0; i < total_pages; ++i) { |
|
|
|
std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first; |
|
|
|
if (cur_page->get_page_type() == "RAW_DATA") { |
|
|
|
raw_page_ids.push_back(i); |
|
|
|
} else if (cur_page->get_page_type() == "BLOB_DATA") { |
|
|
|
blob_id_to_page_id[cur_page->get_page_type_id()] = i; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully."; |
|
|
|
} |
|
|
|
|
|
|
|
if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) { |
|
|
|
write_success_ = false; |
|
|
|
return; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully."; |
|
|
|
shard_no = task_++; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
} // namespace mindrecord |
|
|
|
} // namespace mindspore |